diff options
author | Volpeon <git@volpeon.ink> | 2023-03-03 22:09:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-03 22:09:24 +0100 |
commit | 220806dbd21da3ba83c14096225c31824dfe81df (patch) | |
tree | a201272876bfe894f9d504d1582ac022add4b205 /pipelines | |
parent | Implemented different noise offset (diff) | |
download | textual-inversion-diff-220806dbd21da3ba83c14096225c31824dfe81df.tar.gz textual-inversion-diff-220806dbd21da3ba83c14096225c31824dfe81df.tar.bz2 textual-inversion-diff-220806dbd21da3ba83c14096225c31824dfe81df.zip |
Removed offset noise from training, added init offset to pipeline
Diffstat (limited to 'pipelines')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 95 |
1 files changed, 39 insertions, 56 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index cb09fe1..c4f7401 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -293,53 +293,39 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
293 | 293 | ||
294 | return prompt_embeds | 294 | return prompt_embeds |
295 | 295 | ||
296 | def get_timesteps(self, latents_are_image, num_inference_steps, strength, device): | 296 | def get_timesteps(self, num_inference_steps, strength, device): |
297 | if latents_are_image: | 297 | # get the original timestep using init_timestep |
298 | # get the original timestep using init_timestep | 298 | offset = self.scheduler.config.get("steps_offset", 0) |
299 | offset = self.scheduler.config.get("steps_offset", 0) | 299 | init_timestep = int(num_inference_steps * strength) + offset |
300 | init_timestep = int(num_inference_steps * strength) + offset | 300 | init_timestep = min(init_timestep, num_inference_steps) |
301 | init_timestep = min(init_timestep, num_inference_steps) | 301 | |
302 | 302 | t_start = max(num_inference_steps - init_timestep + offset, 0) | |
303 | t_start = max(num_inference_steps - init_timestep + offset, 0) | 303 | timesteps = self.scheduler.timesteps[t_start:] |
304 | timesteps = self.scheduler.timesteps[t_start:] | ||
305 | else: | ||
306 | timesteps = self.scheduler.timesteps | ||
307 | 304 | ||
308 | timesteps = timesteps.to(device) | 305 | timesteps = timesteps.to(device) |
309 | 306 | ||
310 | return timesteps | 307 | return timesteps |
311 | 308 | ||
312 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): | 309 | def prepare_image(self, batch_size, width, height, dtype, device, generator=None): |
313 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | 310 | return torch.randn( |
314 | 311 | (batch_size, 1, 1, 1), | |
315 | if isinstance(generator, list) and len(generator) != batch_size: | 312 | dtype=dtype, |
316 | raise ValueError( | 313 | device=device, |
317 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | 314 | generator=generator |
318 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." | 315 | ).expand(batch_size, 3, width, height) |
319 | ) | ||
320 | 316 | ||
321 | if latents is None: | 317 | def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): |
322 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | ||
323 | else: | ||
324 | latents = latents.to(device=device, dtype=dtype) | ||
325 | |||
326 | # scale the initial noise by the standard deviation required by the scheduler | ||
327 | latents = latents * self.scheduler.init_noise_sigma | ||
328 | |||
329 | return latents | ||
330 | |||
331 | def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None): | ||
332 | init_image = init_image.to(device=device, dtype=dtype) | 318 | init_image = init_image.to(device=device, dtype=dtype) |
333 | init_latent_dist = self.vae.encode(init_image).latent_dist | 319 | init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
334 | init_latents = init_latent_dist.sample(generator=generator) | 320 | init_latents = self.vae.config.scaling_factor * init_latents |
335 | init_latents = 0.18215 * init_latents | ||
336 | 321 | ||
337 | if batch_size > init_latents.shape[0]: | 322 | if batch_size % init_latents.shape[0] != 0: |
338 | raise ValueError( | 323 | raise ValueError( |
339 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." | 324 | f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." |
340 | ) | 325 | ) |
341 | else: | 326 | else: |
342 | init_latents = torch.cat([init_latents] * batch_size, dim=0) | 327 | batch_multiplier = batch_size // init_latents.shape[0] |
328 | init_latents = torch.cat([init_latents] * batch_multiplier, dim=0) | ||
343 | 329 | ||
344 | # add noise to latents using the timesteps | 330 | # add noise to latents using the timesteps |
345 | noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) | 331 | noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) |
@@ -368,7 +354,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
368 | return extra_step_kwargs | 354 | return extra_step_kwargs |
369 | 355 | ||
370 | def decode_latents(self, latents): | 356 | def decode_latents(self, latents): |
371 | latents = 1 / 0.18215 * latents | 357 | latents = 1 / self.vae.config.scaling_factor * latents |
372 | image = self.vae.decode(latents).sample | 358 | image = self.vae.decode(latents).sample |
373 | image = (image / 2 + 0.5).clamp(0, 1) | 359 | image = (image / 2 + 0.5).clamp(0, 1) |
374 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 360 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
@@ -381,7 +367,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
381 | prompt: Union[str, List[str], List[int], List[List[int]]], | 367 | prompt: Union[str, List[str], List[int], List[List[int]]], |
382 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, | 368 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, |
383 | num_images_per_prompt: int = 1, | 369 | num_images_per_prompt: int = 1, |
384 | strength: float = 0.8, | 370 | strength: float = 1.0, |
385 | height: Optional[int] = None, | 371 | height: Optional[int] = None, |
386 | width: Optional[int] = None, | 372 | width: Optional[int] = None, |
387 | num_inference_steps: int = 50, | 373 | num_inference_steps: int = 50, |
@@ -461,7 +447,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
461 | device = self.execution_device | 447 | device = self.execution_device |
462 | do_classifier_free_guidance = guidance_scale > 1.0 | 448 | do_classifier_free_guidance = guidance_scale > 1.0 |
463 | do_self_attention_guidance = sag_scale > 0.0 | 449 | do_self_attention_guidance = sag_scale > 0.0 |
464 | latents_are_image = isinstance(image, PIL.Image.Image) | ||
465 | 450 | ||
466 | # 3. Encode input prompt | 451 | # 3. Encode input prompt |
467 | prompt_embeds = self.encode_prompt( | 452 | prompt_embeds = self.encode_prompt( |
@@ -474,33 +459,31 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
474 | 459 | ||
475 | # 4. Prepare timesteps | 460 | # 4. Prepare timesteps |
476 | self.scheduler.set_timesteps(num_inference_steps, device=device) | 461 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
477 | timesteps = self.get_timesteps(latents_are_image, num_inference_steps, strength, device) | 462 | timesteps = self.get_timesteps(num_inference_steps, strength, device) |
478 | 463 | ||
479 | # 5. Prepare latent variables | 464 | # 5. Prepare latent variables |
480 | num_channels_latents = self.unet.in_channels | 465 | if isinstance(image, PIL.Image.Image): |
481 | if latents_are_image: | ||
482 | image = preprocess(image) | 466 | image = preprocess(image) |
483 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | 467 | elif image is None: |
484 | latents = self.prepare_latents_from_image( | 468 | image = self.prepare_image( |
485 | image, | ||
486 | latent_timestep, | ||
487 | batch_size * num_images_per_prompt, | 469 | batch_size * num_images_per_prompt, |
488 | prompt_embeds.dtype, | ||
489 | device, | ||
490 | generator | ||
491 | ) | ||
492 | else: | ||
493 | latents = self.prepare_latents( | ||
494 | batch_size * num_images_per_prompt, | ||
495 | num_channels_latents, | ||
496 | height, | ||
497 | width, | 470 | width, |
471 | height, | ||
498 | prompt_embeds.dtype, | 472 | prompt_embeds.dtype, |
499 | device, | 473 | device, |
500 | generator, | 474 | generator |
501 | image, | ||
502 | ) | 475 | ) |
503 | 476 | ||
477 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | ||
478 | latents = self.prepare_latents( | ||
479 | image, | ||
480 | latent_timestep, | ||
481 | batch_size * num_images_per_prompt, | ||
482 | prompt_embeds.dtype, | ||
483 | device, | ||
484 | generator | ||
485 | ) | ||
486 | |||
504 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline | 487 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline |
505 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) | 488 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
506 | 489 | ||