From d9bb4a0d43276c8e120866af044fcf3566930859 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 23 Mar 2023 22:15:17 +0100 Subject: Bring back Perlin offset noise --- training/functional.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 015fe5e..a5b339d 100644 --- a/training/functional.py +++ b/training/functional.py @@ -278,10 +278,11 @@ def loss_step( with_prior_preservation: bool, prior_loss_weight: float, seed: int, + perlin_strength: float, step: int, batch: dict[str, Any], eval: bool = False, - min_snr_gamma: int = 5 + min_snr_gamma: int = 5, ): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample() @@ -300,6 +301,16 @@ def loss_step( generator=generator ) + if perlin_strength != 0: + noise += perlin_strength * perlin_noise( + latents.shape, + res=1, + octaves=4, + dtype=latents.dtype, + device=latents.device, + generator=generator + ) + # Sample a random timestep for each image timesteps = torch.randint( 0, @@ -600,6 +611,7 @@ def train( global_step_offset: int = 0, with_prior_preservation: bool = False, prior_loss_weight: float = 1.0, + perlin_strength: float = 0.1, **kwargs, ): text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( @@ -635,6 +647,7 @@ def train( with_prior_preservation, prior_loss_weight, seed, + perlin_strength, ) if accelerator.is_main_process: -- cgit v1.2.3-70-g09d2