From e32b4d4c04a31b22051740e5f26e16960464f787 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 3 Mar 2023 18:53:15 +0100 Subject: Implemented different noise offset --- training/functional.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 2d582bf..36269f0 100644 --- a/training/functional.py +++ b/training/functional.py @@ -253,7 +253,7 @@ def loss_step( text_encoder: CLIPTextModel, with_prior_preservation: bool, prior_loss_weight: float, - low_freq_noise: float, + noise_offset: float, seed: int, step: int, batch: dict[str, Any], @@ -268,30 +268,19 @@ def loss_step( generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None # Sample noise that we'll add to the latents - noise = torch.randn( - latents.shape, + offsets = noise_offset * torch.randn( + latents.shape[0], 1, 1, 1, dtype=latents.dtype, layout=latents.layout, device=latents.device, generator=generator + ).expand(latents.shape) + noise = torch.normal( + mean=offsets, + std=1, + generator=generator, ) - if low_freq_noise != 0: - low_freq_factor = low_freq_noise * torch.randn( - latents.shape[0], 1, 1, 1, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - noise = noise * (1 - low_freq_factor) + low_freq_factor * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - # Sample a random timestep for each image timesteps = torch.randint( 0, @@ -576,7 +565,7 @@ def train( global_step_offset: int = 0, with_prior_preservation: bool = False, prior_loss_weight: float = 1.0, - low_freq_noise: float = 0.1, + noise_offset: float = 0.2, **kwargs, ): text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( @@ -611,7 +600,7 @@ def train( text_encoder, with_prior_preservation, prior_loss_weight, - low_freq_noise, + noise_offset, seed, ) -- cgit v1.2.3-54-g00ecf