-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py |  4 ++--
-rw-r--r--  training/functional.py                              | 17 +++++++++++++++++++
-rw-r--r--  util/noise.py                                       | 15 +++++++++------
3 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f02dd72..5f4fc38 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -22,7 +22,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging, randn_tensor
+from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from models.clip.util import unify_input_ids, get_extended_embeddings
@@ -308,7 +308,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
         noise = perlin_noise(
-            batch_size, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device
+            batch_size, 1, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device
         ).expand(batch_size, 3, width, height)
         return (1.4 * noise).clamp(-1, 1)
 
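The extra `1` in the pipeline hunk is the new `channels` argument: one grayscale field is generated per sample and then broadcast to RGB. A minimal sketch of the shape mechanics, with `torch.randn` standing in for the repo's `perlin_noise` so the snippet runs on its own:

```python
import torch

batch_size, width, height = 2, 64, 64

# Stand-in for perlin_noise(batch_size, 1, width, height, ...):
# one single-channel field per sample.
noise = torch.randn(batch_size, 1, width, height)

# .expand broadcasts that channel to RGB without copying memory,
# so all three channels carry the same grayscale field.
rgb = noise.expand(batch_size, 3, width, height)
assert rgb.shape == (batch_size, 3, width, height)
assert torch.equal(rgb[:, 0], rgb[:, 1])

# The 1.4 gain widens the spread before clamping to the [-1, 1] image range.
image = (1.4 * rgb).clamp(-1, 1)
```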
diff --git a/training/functional.py b/training/functional.py
index 1c38635..db46766 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
+from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -253,6 +254,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
+    perlin_strength: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -275,6 +277,19 @@ def loss_step(
         generator=generator
     )
 
+    if perlin_strength != 0:
+        noise += perlin_strength * perlin_noise(
+            latents.shape[0],
+            latents.shape[1],
+            latents.shape[2],
+            latents.shape[3],
+            res=1,
+            octaves=4,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -559,6 +574,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
+    perlin_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -593,6 +609,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
+        perlin_strength,
         seed,
     )
 
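The training-side addition blends a low-frequency Perlin field into the Gaussian noise before timesteps are sampled, scaled by the new `perlin_strength` knob (default 0.1, 0 disables it). A sketch of that step in isolation; the latent shape and seed here are illustrative assumptions, since in `loss_step` the latents come from the VAE:

```python
import torch

from util.noise import perlin_noise  # reworked in this commit

# Illustrative SD-style latents: (batch, channels, height, width).
latents = torch.randn(2, 4, 64, 64)
generator = torch.Generator().manual_seed(0)

noise = torch.randn(latents.shape, generator=generator, dtype=latents.dtype)

perlin_strength = 0.1  # the new train() default
if perlin_strength != 0:
    # res=1 with octaves=4 gives a coarse base field plus three finer octaves,
    # one independent field per (sample, channel) pair.
    noise += perlin_strength * perlin_noise(
        latents.shape[0], latents.shape[1], latents.shape[2], latents.shape[3],
        res=1, octaves=4,
        dtype=latents.dtype, device=latents.device, generator=generator,
    )
assert noise.shape == latents.shape
```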
diff --git a/util/noise.py b/util/noise.py
index 38ab172..3c4f82d 100644
--- a/util/noise.py
+++ b/util/noise.py
@@ -1,4 +1,5 @@
 import math
+
 import torch
 
 # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
@@ -47,11 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, device=None, generator=None):
     return noise
 
 
-def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
     return torch.stack([
-        rand_perlin_2d_octaves(
-            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
-        ).unsqueeze(0)
-        for _
-        in range(batch_size)
+        torch.stack([
+            rand_perlin_2d_octaves(
+                (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+            )
+            for _ in range(channels)
+        ])
+        for _ in range(batch_size)
     ])
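After the rework, `perlin_noise` samples one independent 2D field per (sample, channel) pair instead of a single implicit channel per sample, which is why the pipeline hunk above now passes `channels=1` explicitly. A quick usage check of the new signature, assuming `util.noise` is importable from the repo root:

```python
import torch

from util.noise import perlin_noise

noise = perlin_noise(
    2, 4, 64, 64,  # batch_size, channels, width, height
    res=1, octaves=4,
    generator=torch.Generator().manual_seed(0),
)
assert noise.shape == (2, 4, 64, 64)

# Channels are drawn independently now; identical fields would mean the
# generator produced the same draws twice.
assert not torch.equal(noise[:, 0], noise[:, 1])
```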