diff options
Diffstat (limited to 'pipelines/stable_diffusion')
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 2251848..a6b31d8 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -24,7 +24,9 @@ from diffusers import ( | |||
24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput |
25 | from diffusers.utils import logging, randn_tensor | 25 | from diffusers.utils import logging, randn_tensor |
26 | from transformers import CLIPTextModel, CLIPTokenizer | 26 | from transformers import CLIPTextModel, CLIPTokenizer |
27 | |||
27 | from models.clip.util import unify_input_ids, get_extended_embeddings | 28 | from models.clip.util import unify_input_ids, get_extended_embeddings |
29 | from util.noise import perlin_noise | ||
28 | 30 | ||
29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name | 31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
30 | 32 | ||
@@ -304,23 +306,18 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
304 | 306 | ||
305 | return timesteps, num_inference_steps - t_start | 307 | return timesteps, num_inference_steps - t_start |
306 | 308 | ||
307 | def prepare_image(self, batch_size, width, height, max_offset, dtype, device, generator=None): | 309 | def prepare_image(self, batch_size, width, height, dtype, device, generator=None): |
308 | offset = (max_offset * (2 * torch.rand( | 310 | max = 0.4 |
311 | offset = max * (2 * torch.rand( | ||
309 | (batch_size, 1, 1, 1), | 312 | (batch_size, 1, 1, 1), |
310 | dtype=dtype, | 313 | dtype=dtype, |
311 | device=device, | 314 | device=device, |
312 | generator=generator | 315 | generator=generator |
313 | ) - 1)).expand(batch_size, 1, 2, 2) | 316 | ) - 1) |
314 | image = F.interpolate( | 317 | noise = perlin_noise( |
315 | torch.normal( | 318 | batch_size, width, height, res=3, octaves=3, generator=generator, dtype=dtype, device=device |
316 | mean=offset, | ||
317 | std=0.3, | ||
318 | generator=generator | ||
319 | ).clamp(-1, 1), | ||
320 | size=(width, height), | ||
321 | mode="bicubic" | ||
322 | ).expand(batch_size, 3, width, height) | 319 | ).expand(batch_size, 3, width, height) |
323 | return image | 320 | return ((1 + max) * noise + max * offset).clamp(-1, 1) |
324 | 321 | ||
325 | def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): | 322 | def prepare_latents(self, init_image, timestep, batch_size, dtype, device, generator=None): |
326 | init_image = init_image.to(device=device, dtype=dtype) | 323 | init_image = init_image.to(device=device, dtype=dtype) |
@@ -384,7 +381,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
384 | eta: float = 0.0, | 381 | eta: float = 0.0, |
385 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 382 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
386 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 383 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
387 | max_init_offset: float = 0.7, | ||
388 | output_type: str = "pil", | 384 | output_type: str = "pil", |
389 | return_dict: bool = True, | 385 | return_dict: bool = True, |
390 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 386 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
@@ -474,7 +470,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
474 | batch_size * num_images_per_prompt, | 470 | batch_size * num_images_per_prompt, |
475 | width, | 471 | width, |
476 | height, | 472 | height, |
477 | max_init_offset, | ||
478 | prompt_embeds.dtype, | 473 | prompt_embeds.dtype, |
479 | device, | 474 | device, |
480 | generator | 475 | generator |