summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2023-03-03 23:37:40 +0100
committer: Volpeon <git@volpeon.ink> 2023-03-03 23:37:40 +0100
commit: 55fc031aebf48f22c9e646eb4d72246bfdbc5068 (patch)
tree: 318c1eb5351fd96a3b408d7baf0d5ee6adcefa69
parent: Removed offset noise from training, added init offset to pipeline (diff)
download: textual-inversion-diff-55fc031aebf48f22c9e646eb4d72246bfdbc5068.tar.gz
textual-inversion-diff-55fc031aebf48f22c9e646eb4d72246bfdbc5068.tar.bz2
textual-inversion-diff-55fc031aebf48f22c9e646eb4d72246bfdbc5068.zip
Changed init noise algorithm
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py 14
1 files changed, 11 insertions, 3 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index c4f7401..242be29 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -306,13 +306,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
306 306
307 return timesteps 307 return timesteps
308 308
309 def prepare_image(self, batch_size, width, height, dtype, device, generator=None): 309 def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None):
310 return torch.randn( 310 offset = (max_offset * (2 * torch.rand(
311 (batch_size, 1, 1, 1), 311 (batch_size, 1, 1, 1),
312 dtype=dtype, 312 dtype=dtype,
313 device=device, 313 device=device,
314 generator=generator 314 generator=generator
315 ).expand(batch_size, 3, width, height) 315 ) - 1)).expand(batch_size, 3, width, height)
316 image = (.1 * torch.normal(
317 mean=offset,
318 std=1,
319 generator=generator
320 )).clamp(-1, 1)
321 return image
316 322
317 def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): 323 def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None):
318 init_image = init_image.to(device=device, dtype=dtype) 324 init_image = init_image.to(device=device, dtype=dtype)
@@ -376,6 +382,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
376 eta: float = 0.0, 382 eta: float = 0.0,
377 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 383 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
378 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, 384 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
385 max_image_offset: float = 1.0,
379 output_type: str = "pil", 386 output_type: str = "pil",
380 return_dict: bool = True, 387 return_dict: bool = True,
381 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 388 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -469,6 +476,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
469 batch_size * num_images_per_prompt, 476 batch_size * num_images_per_prompt,
470 width, 477 width,
471 height, 478 height,
479 max_image_offset,
472 prompt_embeds.dtype, 480 prompt_embeds.dtype,
473 device, 481 device,
474 generator 482 generator