summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--infer.py3
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py71
-rw-r--r--schedulers/scheduling_euler_a.py24
3 files changed, 38 insertions, 60 deletions
diff --git a/infer.py b/infer.py
index c40335c..f2c380f 100644
--- a/infer.py
+++ b/infer.py
@@ -28,7 +28,7 @@ default_cmds = {
28 "prompt": None, 28 "prompt": None,
29 "negative_prompt": None, 29 "negative_prompt": None,
30 "image": None, 30 "image": None,
31 "image_strength": .7, 31 "image_strength": .3,
32 "width": 512, 32 "width": 512,
33 "height": 512, 33 "height": 512,
34 "batch_size": 1, 34 "batch_size": 1,
@@ -225,6 +225,7 @@ def generate(output_dir, pipeline, args):
225 guidance_scale=args.guidance_scale, 225 guidance_scale=args.guidance_scale,
226 generator=generator, 226 generator=generator,
227 latents=init_image, 227 latents=init_image,
228 strength=args.image_strength,
228 ).images 229 ).images
229 230
230 for j, image in enumerate(images): 231 for j, image in enumerate(images):
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a8ecedf..b4c85e9 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -178,35 +178,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
178 # set timesteps 178 # set timesteps
179 self.scheduler.set_timesteps(num_inference_steps) 179 self.scheduler.set_timesteps(num_inference_steps)
180 180
181 offset = self.scheduler.config.get("steps_offset", 0)
182
183 if latents is not None and isinstance(latents, PIL.Image.Image):
184 latents = preprocess(latents, width, height)
185 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
186 latents = latent_dist.sample(generator=generator)
187 latents = 0.18215 * latents
188
189 # expand init_latents for batch_size
190 latents = torch.cat([latents] * batch_size)
191
192 # get the original timestep using init_timestep
193 init_timestep = int(num_inference_steps * strength) + offset
194 init_timestep = min(init_timestep, num_inference_steps)
195
196 if isinstance(self.scheduler, LMSDiscreteScheduler):
197 timesteps = torch.tensor(
198 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
199 )
200 else:
201 timesteps = self.scheduler.timesteps[-init_timestep]
202 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
203
204 # add noise to latents using the timesteps
205 noise = torch.randn(latents.shape, generator=generator, device=self.device)
206 latents = self.scheduler.add_noise(latents, noise, timesteps)
207 else:
208 init_timestep = num_inference_steps + offset
209
210 # get prompt text embeddings 181 # get prompt text embeddings
211 text_inputs = self.tokenizer( 182 text_inputs = self.tokenizer(
212 prompt, 183 prompt,
@@ -243,6 +214,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
243 # to avoid doing two forward passes 214 # to avoid doing two forward passes
244 text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 215 text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
245 216
217 offset = self.scheduler.config.get("steps_offset", 0)
218 init_timestep = num_inference_steps + offset
219 ensure_sigma = not isinstance(latents, PIL.Image.Image)
220
246 # get the initial random noise unless the user supplied it 221 # get the initial random noise unless the user supplied it
247 222
248 # Unlike in other pipelines, latents need to be generated in the target device 223 # Unlike in other pipelines, latents need to be generated in the target device
@@ -257,23 +232,48 @@ class VlpnStableDiffusion(DiffusionPipeline):
257 device=latents_device, 232 device=latents_device,
258 dtype=text_embeddings.dtype, 233 dtype=text_embeddings.dtype,
259 ) 234 )
235 elif isinstance(latents, PIL.Image.Image):
236 latents = preprocess(latents, width, height)
237 latent_dist = self.vae.encode(latents.to(self.device)).latent_dist
238 latents = latent_dist.sample(generator=generator)
239 latents = 0.18215 * latents
240
241 # expand init_latents for batch_size
242 latents = torch.cat([latents] * batch_size)
243
244 # get the original timestep using init_timestep
245 init_timestep = int(num_inference_steps * strength) + offset
246 init_timestep = min(init_timestep, num_inference_steps)
247
248 if isinstance(self.scheduler, LMSDiscreteScheduler):
249 timesteps = torch.tensor(
250 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
251 )
252 else:
253 timesteps = self.scheduler.timesteps[-init_timestep]
254 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
255
256 # add noise to latents using the timesteps
257 noise = torch.randn(latents.shape, generator=generator, device=self.device)
258 latents = self.scheduler.add_noise(latents, noise, timesteps)
260 else: 259 else:
261 if latents.shape != latents_shape: 260 if latents.shape != latents_shape:
262 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") 261 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
263 latents = latents.to(self.device) 262 latents = latents.to(self.device)
264 263
264 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
265 if ensure_sigma:
266 if isinstance(self.scheduler, LMSDiscreteScheduler):
267 latents = latents * self.scheduler.sigmas[0]
268 elif isinstance(self.scheduler, EulerAScheduler):
269 latents = latents * self.scheduler.sigmas[0]
270
265 t_start = max(num_inference_steps - init_timestep + offset, 0) 271 t_start = max(num_inference_steps - init_timestep + offset, 0)
266 272
267 # Some schedulers like PNDM have timesteps as arrays 273 # Some schedulers like PNDM have timesteps as arrays
268 # It's more optimzed to move all timesteps to correct device beforehand 274 # It's more optimzed to move all timesteps to correct device beforehand
269 timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) 275 timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device)
270 276
271 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
272 if isinstance(self.scheduler, LMSDiscreteScheduler):
273 latents = latents * self.scheduler.sigmas[0]
274 elif isinstance(self.scheduler, EulerAScheduler):
275 latents = latents * self.scheduler.sigmas[0]
276
277 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 277 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
278 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 278 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
279 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 279 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -292,6 +292,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
292 292
293 # expand the latents if we are doing classifier free guidance 293 # expand the latents if we are doing classifier free guidance
294 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 294 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
295
295 if isinstance(self.scheduler, LMSDiscreteScheduler): 296 if isinstance(self.scheduler, LMSDiscreteScheduler):
296 sigma = self.scheduler.sigmas[t_index] 297 sigma = self.scheduler.sigmas[t_index]
297 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 298 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 1b1c9cf..a2d0e9f 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -191,27 +191,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
191 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) 191 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
192 self.timesteps = np.arange(0, self.num_inference_steps) 192 self.timesteps = np.arange(0, self.num_inference_steps)
193 193
194 def add_noise_to_input(
195 self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
196 ) -> Tuple[torch.FloatTensor, float]:
197 """
198 Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
199 higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
200
201 TODO Args:
202 """
203 if self.config.s_min <= sigma <= self.config.s_max:
204 gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
205 else:
206 gamma = 0
207
208 # sample eps ~ N(0, S_noise^2 * I)
209 eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
210 sigma_hat = sigma + gamma * sigma
211 sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
212
213 return sample_hat, sigma_hat
214
215 def step( 194 def step(
216 self, 195 self,
217 model_output: torch.FloatTensor, 196 model_output: torch.FloatTensor,
@@ -219,9 +198,6 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
219 timestep_prev: int, 198 timestep_prev: int,
220 sample: torch.FloatTensor, 199 sample: torch.FloatTensor,
221 generator: None, 200 generator: None,
222 # ,sigma_hat: float,
223 # sigma_prev: float,
224 # sample_hat: torch.FloatTensor,
225 return_dict: bool = True, 201 return_dict: bool = True,
226 ) -> Union[SchedulerOutput, Tuple]: 202 ) -> Union[SchedulerOutput, Tuple]:
227 """ 203 """