From 220c842d22f282544e4d12d277a40f39f85d3c35 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 4 Mar 2023 15:08:51 +0100
Subject: Added Perlin noise to training

---
 training/functional.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/training/functional.py b/training/functional.py
index 1c38635..db46766 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
+from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -253,6 +254,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
+    perlin_strength: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -275,6 +277,19 @@ def loss_step(
         generator=generator
     )
 
+    if perlin_strength != 0:
+        noise += perlin_strength * perlin_noise(
+            latents.shape[0],
+            latents.shape[1],
+            latents.shape[2],
+            latents.shape[3],
+            res=1,
+            octaves=4,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -559,6 +574,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
+    perlin_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -593,6 +609,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
+        perlin_strength,
         seed,
     )
--
cgit v1.2.3-70-g09d2
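
Note: util/noise.py itself is not part of this commit, so the implementation of
perlin_noise is not shown here. A minimal sketch of a helper with the signature
the diff calls, perlin_noise(batch, channels, height, width, res=..., octaves=...,
dtype=..., device=..., generator=...), could look like the following. It follows
the widely used grid-gradient PyTorch Perlin recipe; the internals are an
assumption, not necessarily the author's actual code.

    import math

    import torch


    def rand_perlin_2d(shape, res, device=None, generator=None):
        # Single octave of 2D Perlin noise. `shape` (H, W) must be divisible
        # by `res` (gradient cells per axis). Computed in float32 for stability.
        delta = (res[0] / shape[0], res[1] / shape[1])
        d = (shape[0] // res[0], shape[1] // res[1])

        # Fractional coordinate of every pixel inside its grid cell.
        grid = torch.stack(torch.meshgrid(
            torch.arange(0, res[0], delta[0], device=device),
            torch.arange(0, res[1], delta[1], device=device),
            indexing="ij",
        ), dim=-1) % 1

        # Random unit gradient vector at every cell corner.
        angles = 2 * math.pi * torch.rand(
            res[0] + 1, res[1] + 1, device=device, generator=generator
        )
        gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

        def corner_grads(s0, s1):
            # Broadcast each corner's gradient over the pixels of its cell.
            return (gradients[s0[0]:s0[1], s1[0]:s1[1]]
                    .repeat_interleave(d[0], 0)
                    .repeat_interleave(d[1], 1))

        def dot(grad, shift):
            offset = torch.stack((grid[..., 0] + shift[0], grid[..., 1] + shift[1]), dim=-1)
            return (offset * grad).sum(dim=-1)

        # Dot products with the four surrounding corner gradients.
        n00 = dot(corner_grads((0, -1), (0, -1)), (0, 0))
        n10 = dot(corner_grads((1, None), (0, -1)), (-1, 0))
        n01 = dot(corner_grads((0, -1), (1, None)), (0, -1))
        n11 = dot(corner_grads((1, None), (1, None)), (-1, -1))

        # Smoothstep fade, then bilinear interpolation.
        t = grid * grid * grid * (grid * (grid * 6 - 15) + 10)
        n0 = torch.lerp(n00, n10, t[..., 0])
        n1 = torch.lerp(n01, n11, t[..., 0])
        return math.sqrt(2) * torch.lerp(n0, n1, t[..., 1])


    def perlin_noise(batch, channels, height, width, res=1, octaves=1,
                     dtype=None, device=None, generator=None):
        # Fractal Perlin noise shaped like a latent batch: (batch, channels, height, width).
        # Each octave doubles the frequency and halves the amplitude; with res=1 and
        # octaves=4, height and width must be divisible by 8. The generator must live
        # on `device`.
        out = torch.zeros(batch, channels, height, width, device=device)
        for b in range(batch):
            for c in range(channels):
                frequency, amplitude, total = res, 1.0, 0.0
                for _ in range(octaves):
                    out[b, c] += amplitude * rand_perlin_2d(
                        (height, width), (frequency, frequency),
                        device=device, generator=generator,
                    )
                    total += amplitude
                    frequency *= 2
                    amplitude /= 2
                out[b, c] /= total  # keep amplitude roughly independent of octave count
        return out.to(dtype)

Note also that the diff adds the Perlin component on top of the unit-variance
Gaussian noise without renormalizing, so the target the model learns to predict
has slightly more than unit variance; with the default perlin_strength of 0.1
the deviation stays small. The extra low-frequency component presumably serves
a purpose similar to offset/pyramid noise: giving the model a training signal
for large-scale structure that pure Gaussian latent noise under-represents.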