author    Volpeon <git@volpeon.ink>    2023-03-04 19:24:24 +0100
committer Volpeon <git@volpeon.ink>    2023-03-04 19:24:24 +0100
commit    bc28ad0e0355916cb7e0b2df5ee0992f2e0b427c (patch)
tree      88505e6fb13666ba459577935151aab43ee019d2
parent    Added Perlin noise to training (diff)
More flexible pipeline wrt init noise
-rw-r--r--    pipelines/stable_diffusion/vlpn_stable_diffusion.py    57
1 file changed, 44 insertions(+), 13 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 5f4fc38..f27be78 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,7 +1,7 @@
 import inspect
 import warnings
 import math
-from typing import List, Dict, Any, Optional, Union, Callable
+from typing import List, Dict, Any, Optional, Union, Callable, Literal
 
 import numpy as np
 import torch
@@ -22,7 +22,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging
+from diffusers.utils import logging, randn_tensor
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from models.clip.util import unify_input_ids, get_extended_embeddings
@@ -312,7 +312,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         ).expand(batch_size, 3, width, height)
         return (1.4 * noise).clamp(-1, 1)
 
-    def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
+    def prepare_latents_from_image(self, init_image, timestep, batch_size, dtype, device, generator=None):
         init_image = init_image.to(device=device, dtype=dtype)
         init_latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
         init_latents = self.vae.config.scaling_factor * init_latents
@@ -334,6 +334,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return latents
 
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
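The added prepare_latents() follows the stock diffusers txt2img convention: absent user-supplied latents, it draws Gaussian noise with randn_tensor and scales it by the scheduler's init_noise_sigma. A minimal standalone sketch of that behavior (the batch size, channel count, and vae_scale_factor of 8 below are illustrative assumptions, not values taken from this repository):

import torch
from diffusers.utils import randn_tensor  # same helper the patch imports

# Illustrative SD-style dimensions: 4 latent channels, VAE downscale factor 8.
batch_size, num_channels_latents, vae_scale_factor = 2, 4, 8
height = width = 512
shape = (batch_size, num_channels_latents,
         height // vae_scale_factor, width // vae_scale_factor)

generator = torch.Generator().manual_seed(0)
latents = randn_tensor(shape, generator=generator, dtype=torch.float32)

# Scale by the scheduler's expected starting noise level: 1.0 for DDIM/PNDM,
# larger for Euler-style schedulers.
latents = latents * 1.0  # stands in for self.scheduler.init_noise_sigma
print(latents.shape)  # torch.Size([2, 4, 64, 64])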
@@ -373,7 +390,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         sag_scale: float = 0.75,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
+        image: Optional[Union[torch.FloatTensor, PIL.Image.Image, Literal["noise"]]] = None,
         output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -443,8 +460,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 2. Define call parameters
         batch_size = len(prompt)
         device = self.execution_device
+        num_channels_latents = self.unet.in_channels
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
+        prep_from_image = isinstance(image, PIL.Image.Image) or image == "noise"
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -458,7 +477,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         # 4. Prepare latent variables
         if isinstance(image, PIL.Image.Image):
             image = preprocess(image)
-        elif image is None:
+        elif image == "noise":
             image = self.prepare_image(
                 batch_size * num_images_per_prompt,
                 width,
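With prep_from_image in place, the image argument now selects between two latent-preparation paths: PIL images and the "noise" sentinel are routed through the VAE encoder, while None and raw latent tensors go straight to prepare_latents(). A small sketch of that predicate, using only names visible in the diff (wants_image_latents is a hypothetical stand-in for illustration):

import PIL.Image

def wants_image_latents(image) -> bool:
    # True when latents must be encoded from an image: either a real init
    # image, or the "noise" sentinel that first synthesizes a noise image
    # via prepare_image().
    return isinstance(image, PIL.Image.Image) or image == "noise"

assert wants_image_latents("noise")
assert wants_image_latents(PIL.Image.new("RGB", (64, 64)))
assert not wants_image_latents(None)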
@@ -474,14 +493,26 @@ class VlpnStableDiffusion(DiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
-        latents = self.prepare_latents(
-            image,
-            latent_timestep,
-            batch_size * num_images_per_prompt,
-            prompt_embeds.dtype,
-            device,
-            generator
-        )
+        if prep_from_image:
+            latents = self.prepare_latents_from_image(
+                image,
+                latent_timestep,
+                batch_size * num_images_per_prompt,
+                prompt_embeds.dtype,
+                device,
+                generator
+            )
+        else:
+            latents = self.prepare_latents(
+                batch_size,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                image
+            )
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
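Taken end to end, the reworked __call__ distinguishes three initialization modes. A hypothetical usage sketch, assuming pipe is an already constructed VlpnStableDiffusion and that prompts are passed as lists, as the batch_size = len(prompt) line suggests; the prompt and file name are placeholders:

import torch
from PIL import Image

generator = torch.Generator().manual_seed(42)

# 1. image=None (default): pure txt2img; latents come from the new
#    prepare_latents(), i.e. plain Gaussian noise via randn_tensor.
result = pipe(["a watercolor fox"], generator=generator)

# 2. image="noise": restores the previous image-is-None behavior, where a
#    clamped noise image from prepare_image() is encoded through the VAE in
#    prepare_latents_from_image().
result = pipe(["a watercolor fox"], image="noise", generator=generator)

# 3. image=<PIL.Image.Image>: img2img; the init image is encoded to latents
#    and noised to the starting timestep.
init = Image.open("fox.png").convert("RGB")
result = pipe(["a watercolor fox"], image=init, generator=generator)

# A torch.FloatTensor passed as image is treated as ready-made latents and
# forwarded to prepare_latents() (only moved to the device and sigma-scaled).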