From 7c56bf2fa99a8f6b7e23bb66ef4a2364dec3fbbd Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 14 Feb 2023 11:02:41 +0100 Subject: Made low-freq noise configurable --- training/functional.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index 62b8260..a9c7a8a 100644 --- a/training/functional.py +++ b/training/functional.py @@ -256,6 +256,7 @@ def loss_step( text_encoder: CLIPTextModel, with_prior_preservation: bool, prior_loss_weight: float, + low_freq_noise: float, seed: int, step: int, batch: dict[str, Any], @@ -274,13 +275,15 @@ def loss_step( layout=latents.layout, device=latents.device, generator=generator - ) + 0.1 * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator ) + if low_freq_noise > 0: + noise += low_freq_noise * torch.randn( + latents.shape[0], latents.shape[1], 1, 1, + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + generator=generator + ) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint( @@ -553,6 +556,7 @@ def train( global_step_offset: int = 0, with_prior_preservation: bool = False, prior_loss_weight: float = 1.0, + low_freq_noise: float = 0.05, **kwargs, ): text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( @@ -587,6 +591,7 @@ def train( text_encoder, with_prior_preservation, prior_loss_weight, + low_freq_noise, seed, ) -- cgit v1.2.3-54-g00ecf