From 58ea0d0967b7845f3fb9996e353d5c3918407f98 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 5 Apr 2023 16:03:31 +0200
Subject: New offset noise test

---
 training/functional.py | 97 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 89 insertions(+), 8 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index 96ecbc1..e7f02cb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -270,6 +270,64 @@ def snr_weight(noisy_latents, latents, gamma):
     )
 
 
+def slerp(v1, v2, t, DOT_THR=0.9995, to_cpu=False, zdim=-1):
+    """SLERP for PyTorch tensors, interpolating `v1` to `v2` with scale `t`.
+
+    `DOT_THR` determines when the vectors are too close to parallel.
+    If they are too close, a regular linear interpolation is used instead.
+
+    `to_cpu` is a flag that optionally computes SLERP on the CPU.
+    If the input tensors were on a GPU, the result is moved back afterwards.
+
+    `zdim` is the feature dimension over which to compute norms and find angles.
+    For example: if a sequence of 5 vectors is input with shape [5, 768],
+    then `zdim = 1` or `zdim = -1` computes SLERP along the feature dim of 768.
+
+    Theory reference:
+    https://splines.readthedocs.io/en/latest/rotation/slerp.html
+    PyTorch reference:
+    https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
+    NumPy reference:
+    https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
+    """
+
+    # check if we need to move to the cpu
+    if to_cpu:
+        orig_device = v1.device
+        v1, v2 = v1.to('cpu'), v2.to('cpu')
+
+    # take the dot product between normalized vectors
+    v1_norm = v1 / torch.norm(v1, dim=zdim, keepdim=True)
+    v2_norm = v2 / torch.norm(v2, dim=zdim, keepdim=True)
+    dot = (v1_norm * v2_norm).sum(zdim)
+
+    # unsqueeze the dot product so it broadcasts against the inputs
+    for _ in range(len(dot.shape), len(v1.shape)):
+        dot = dot[..., None]
+
+    # if the vectors are too close, return a simple linear interpolation
+    if (torch.abs(dot) > DOT_THR).any():
+        res = (1 - t) * v1 + t * v2
+    else:
+        # compute the angle terms we need
+        theta = torch.acos(dot)
+        theta_t = theta * t
+        sin_theta = torch.sin(theta)
+        sin_theta_t = torch.sin(theta_t)
+
+        # compute the sine scaling terms for the vectors
+        s1 = torch.sin(theta - theta_t) / sin_theta
+        s2 = sin_theta_t / sin_theta
+
+        # interpolate the vectors
+        res = s1 * v1 + s2 * v2
+
+    # move the result back to the original device if requested
+    # (`Tensor.to` is not in-place, so the result must be reassigned)
+    if to_cpu:
+        res = res.to(orig_device)
+
+    return res
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
@@ -279,10 +337,11 @@ def loss_step(
     prior_loss_weight: float,
     seed: int,
     offset_noise_strength: float,
+    min_snr_gamma: int,
     step: int,
     batch: dict[str, Any],
+    cache: dict[Any, Any],
     eval: bool = False,
-    min_snr_gamma: int = 5,
 ):
     images = batch["pixel_values"]
     generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
@@ -302,13 +361,31 @@ def loss_step(
     )
 
     if offset_noise_strength != 0:
-        offset_noise = torch.randn(
-            (latents.shape[0], latents.shape[1], 1, 1),
+        cache_key = f"img_white_{images.shape[2]}_{images.shape[3]}"
+        # encode a plain white image once per resolution and cache its latents
+        if cache_key not in cache:
+            img_white = torch.tensor(
+                [[[[1]]]],
+                dtype=latents.dtype,
+                device=latents.device
+            ).expand(1, images.shape[1], images.shape[2], images.shape[3])
+            img_white = img_white * 2 - 1
+            img_white = vae.encode(img_white).latent_dist.sample(generator=generator)
+            img_white *= vae.config.scaling_factor
+            cache[cache_key] = img_white
+        else:
+            img_white = cache[cache_key]
+
+        offset_strength = torch.rand(
+            (bsz, 1, 1, 1),
             dtype=latents.dtype,
+            layout=latents.layout,
             device=latents.device,
             generator=generator
-        ).expand(noise.shape)
-        noise += offset_noise_strength * offset_noise
+        )
+        offset_strength = offset_noise_strength * (offset_strength * 2 - 1)
+        offset_strength = offset_strength.expand(noise.shape)
+        noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2))
 
     # Sample a random timestep for each image
     timesteps = torch.randint(
@@ -382,7 +459,8 @@ def loss_step(
 
 
 class LossCallable(Protocol):
-    def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ...
+    def __call__(self, step: int, batch: dict[str, Any], cache: dict[Any, Any],
+                 eval: bool = False) -> Tuple[Any, Any, int]: ...
 
 
 def train_loop(
@@ -407,6 +485,7 @@ def train_loop(
     num_val_steps = num_val_steps_per_epoch * num_epochs
 
     global_step = 0
+    cache = {}
 
     avg_loss = AverageMeter()
     avg_acc = AverageMeter()
@@ -476,7 +555,7 @@ def train_loop(
 
             with on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
-                    loss, acc, bsz = loss_step(step, batch)
+                    loss, acc, bsz = loss_step(step, batch, cache)
                     loss /= gradient_accumulation_steps
 
                     accelerator.backward(loss)
@@ -541,7 +620,7 @@ def train_loop(
 
                 with torch.inference_mode(), on_eval():
                     for step, batch in enumerate(val_dataloader):
-                        loss, acc, bsz = loss_step(step, batch, True)
+                        loss, acc, bsz = loss_step(step, batch, cache, True)
 
                         loss = loss.detach_()
                         acc = acc.detach_()
@@ -633,6 +712,7 @@ def train(
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     offset_noise_strength: float = 0.15,
+    min_snr_gamma: int = 5,
     **kwargs,
 ):
    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -665,6 +745,7 @@ def train(
         prior_loss_weight,
         seed,
         offset_noise_strength,
+        min_snr_gamma,
     )
 
     if accelerator.is_main_process:
--
cgit v1.2.3-70-g09d2
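
For reference, a minimal sketch of what the new offset-noise path computes. This is illustrative only, not part of the patch: the random `img_white` stand-in is an assumption (in the patch it is the cached, scaled VAE encoding of an all-white image), and it assumes `training.functional` is importable.

import torch

from training.functional import slerp  # helper added in this commit

bsz, channels, height, width = 4, 4, 64, 64

# per-sample Gaussian noise, as sampled in loss_step
noise = torch.randn(bsz, channels, height, width)

# stand-in for `img_white`; random values are used here only so the
# snippet runs without a VAE
img_white = torch.randn(1, channels, height, width)

# per-sample interpolation strength in [-0.15, 0.15], mirroring
# offset_noise_strength * (torch.rand(bsz, 1, 1, 1) * 2 - 1)
offset_strength = 0.15 * (torch.rand(bsz, 1, 1, 1) * 2 - 1)
offset_strength = offset_strength.expand(noise.shape)

# rotate each noise sample toward (or, for negative strengths, away from)
# the white-image latent, with norms and angles taken over the spatial dims
noise = slerp(noise, img_white.expand(noise.shape), offset_strength, zdim=(-1, -2))
print(noise.shape)  # torch.Size([4, 4, 64, 64])

Compared with the previous additive scheme (noise += offset_noise_strength * offset_noise), slerping toward the white-image latent rotates each sample rather than shifting it, which keeps the noise magnitude roughly unchanged for small strengths; that appears to be the motivation for this test.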