From 220c842d22f282544e4d12d277a40f39f85d3c35 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 4 Mar 2023 15:08:51 +0100
Subject: Added Perlin noise to training

---
 pipelines/stable_diffusion/vlpn_stable_diffusion.py |  4 ++--
 training/functional.py                              | 17 +++++++++++++++++
 util/noise.py                                       | 15 +++++++++------
 3 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index f02dd72..5f4fc38 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -22,7 +22,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import logging, randn_tensor
+from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from models.clip.util import unify_input_ids, get_extended_embeddings
@@ -308,7 +308,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def prepare_image(self, batch_size, width, height, dtype, device, generator=None):
         noise = perlin_noise(
-            batch_size, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device
+            batch_size, 1, width, height, res=1, octaves=4, generator=generator, dtype=dtype, device=device
         ).expand(batch_size, 3, width, height)
         return (1.4 * noise).clamp(-1, 1)
 
diff --git a/training/functional.py b/training/functional.py
index 1c38635..db46766 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
+from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -253,6 +254,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
+    perlin_strength: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -275,6 +277,19 @@ def loss_step(
         generator=generator
     )
 
+    if perlin_strength != 0:
+        noise += perlin_strength * perlin_noise(
+            latents.shape[0],
+            latents.shape[1],
+            latents.shape[2],
+            latents.shape[3],
+            res=1,
+            octaves=4,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -559,6 +574,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
+    perlin_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -593,6 +609,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
+        perlin_strength,
         seed,
     )
 
diff --git a/util/noise.py b/util/noise.py
index 38ab172..3c4f82d 100644
--- a/util/noise.py
+++ b/util/noise.py
@@ -1,4 +1,5 @@
 import math
+
 import torch
 
 # 2D Perlin noise in PyTorch https://gist.github.com/vadimkantorov/ac1b097753f217c5c11bc2ff396e0a57
@@ -47,11 +48,13 @@ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5, dtype=None, d
     return noise
 
 
-def perlin_noise(batch_size: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
+def perlin_noise(batch_size: int, channels: int, width: int, height: int, res=8, octaves=1, dtype=None, device=None, generator=None):
     return torch.stack([
-        rand_perlin_2d_octaves(
-            (width, height), (res, res), octaves, dtype=dtype, device=device
-        ).unsqueeze(0)
-        for _
-        in range(batch_size)
+        torch.stack([
+            rand_perlin_2d_octaves(
+                (width, height), (res, res), octaves, dtype=dtype, device=device, generator=generator
+            )
+            for _ in range(channels)
+        ])
+        for _ in range(batch_size)
     ])
--
cgit v1.2.3-70-g09d2
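
A minimal usage sketch of the changed perlin_noise() signature, assuming
util/noise.py from this patch is importable and that latents follow the
usual (batch, channels, height, width) layout; the tensor sizes below are
illustrative stand-ins, not values from the patch:

    import torch

    from util.noise import perlin_noise

    latents = torch.randn(4, 4, 64, 64)  # stand-in for VAE latents
    generator = torch.Generator(device=latents.device).manual_seed(0)
    perlin_strength = 0.1  # the new train() default

    # Base Gaussian noise, as sampled in loss_step().
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        device=latents.device,
        generator=generator,
    )

    # perlin_noise() now takes the channel count as its second positional
    # argument and returns an independent low-frequency field per batch
    # item and per channel, shaped (batch, channels, width, height).
    if perlin_strength != 0:
        noise += perlin_strength * perlin_noise(
            latents.shape[0],
            latents.shape[1],
            latents.shape[2],
            latents.shape[3],
            res=1,
            octaves=4,
            dtype=latents.dtype,
            device=latents.device,
            generator=generator,
        )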
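
Note on the pipeline call site: prepare_image() keeps its previous behavior
by requesting a single channel and expand()ing it across the three image
channels (a view, so no extra memory); the new channels parameter exists so
loss_step() can draw an independent field per latent channel instead.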