 pipelines/stable_diffusion/vlpn_stable_diffusion.py |  4 ++--
 training/functional.py                              | 17 +++++++++++++++++
 util/noise.py                                       | 15 +++++++++------
 3 files changed, 28 insertions(+), 8 deletions(-)
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f02dd72..5f4fc38 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -22,7 +22,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging, randn_tensor
+from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from models.clip.util import unify_input_ids, get_extended_embeddings
@@ -308,7 +308,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
         noise = perlin_noise(
-            batch_size, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device
+            batch_size, 1, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device
         ).expand(batch_size, 3, width, height)
         return (1.4 * noise).clamp(-1, 1)
 
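
Note: with the new channel argument, prepare_image asks perlin_noise for a single
channel and broadcasts it to RGB via expand. A minimal sketch of the resulting
logic, assuming perlin_noise returns a (batch_size, channels, width, height)
tensor as the util/noise.py change below suggests:

    # Hypothetical illustration of the prepare_image call above.
    noise = perlin_noise(
        batch_size, 1, width, height, res=1, octaves=4,
        generator=generator, dtype=dtype, device=device,
    )  # -> (batch_size, 1, width, height)
    image = noise.expand(batch_size, 3, width, height)  # reuse one channel for RGB
    image = (1.4 * image).clamp(-1, 1)  # boost amplitude, keep values in [-1, 1]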
diff --git a/training/functional.py b/training/functional.py
index 1c38635..db46766 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
+from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -253,6 +254,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
+    perlin_strength: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -275,6 +277,19 @@ def loss_step(
         generator=generator
     )
 
+    if perlin_strength != 0:
+        noise += perlin_strength * perlin_noise(
+            latents.shape[0],
+            latents.shape[1],
+            latents.shape[2],
+            latents.shape[3],
+            res=1,
+            octaves=4,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -559,6 +574,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
+    perlin_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -593,6 +609,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
+        perlin_strength,
         seed,
     )
 
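Note: the training change blends low-frequency Perlin noise into the Gaussian
noise that the scheduler later adds to the latents, scaled by perlin_strength
(default 0.1, disabled at 0). A standalone sketch of the block above;
add_training_noise is a hypothetical helper name, not part of the diff:

    import torch
    from util.noise import perlin_noise

    def add_training_noise(latents, perlin_strength=0.1, generator=None):
        # Base Gaussian noise, as in loss_step before this change.
        noise = torch.randn(
            latents.shape, dtype=latents.dtype, device=latents.device,
            generator=generator,
        )
        # Blend in Perlin noise with the same shape as the latents.
        if perlin_strength != 0:
            noise += perlin_strength * perlin_noise(
                latents.shape[0], latents.shape[1],
                latents.shape[2], latents.shape[3],
                res=1, octaves=4,
                dtype=latents.dtype, device=latents.device, generator=generator,
            )
        return noise
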
diff --git a/util/noise.py b/util/noise.py
index 38ab172..3c4f82d 100644
--- a/util/noise.py
+++ b/util/noise.py
@@ -1,4 +1,5 @@
 import math
+
 import torch
 
 # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
@@ -47,11 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d
     return noise
 
 
-def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
     return torch.stack([
-        rand_perlin_2d_octaves(
-            (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
-        ).unsqueeze(0)
-        for _
-        in range(batch_size)
+        torch.stack([
+            rand_perlin_2d_octaves(
+                (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+            )
+            for _ in range(channels)
+        ])
+        for _ in range(batch_size)
     ])
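
Note: perlin_noise now stacks an independent noise field per channel for every
batch entry, so the output shape is (batch_size, channels, width, height)
rather than the fixed (batch_size, 1, width, height) of the old version. A
usage sketch (the 64x64 size is an arbitrary choice that res=1 with octaves=4
divides evenly):

    import torch
    from util.noise import perlin_noise

    gen = torch.Generator().manual_seed(0)
    noise = perlin_noise(4, 4, 64, 64, res=1, octaves=4, generator=gen)
    assert noise.shape == (4, 4, 64, 64)  # (batch_size, channels, width, height)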