From 881e20dd52cb68dbe6b8f0a78c82a4ffcf3dea6d Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sat, 24 Jun 2023 20:33:23 +0200
Subject: New loss scaling

---
 training/functional.py | 108 +++++++++++++++++++------------------------------
 1 file changed, 41 insertions(+), 67 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index cc079ef..8917eb7 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -94,6 +94,8 @@ def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
         pretrained_model_name_or_path, subfolder="scheduler"
     )
 
+    prepare_scheduler_for_custom_training(noise_scheduler, "cuda")
+
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
 
@@ -273,68 +275,39 @@ def add_placeholder_tokens(
     return placeholder_token_ids, initializer_token_ids
 
 
-def compute_snr(timesteps, noise_scheduler):
-    """
-    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
-    """
-    alphas_cumprod = noise_scheduler.alphas_cumprod
-    sqrt_alphas_cumprod = alphas_cumprod**0.5
-    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
-    # Expand the tensors.
-    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
-        timesteps
-    ].float()
-    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
-        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
-    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
-
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
-        device=timesteps.device
-    )[timesteps].float()
-    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
-        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
-    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
-
-    # Compute SNR.
-    snr = (alpha / sigma) ** 2
-    return snr
-
-
-def get_original(
-    noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
-):
+def prepare_scheduler_for_custom_training(noise_scheduler, device):
+    if hasattr(noise_scheduler, "all_snr"):
+        return
+
     alphas_cumprod = noise_scheduler.alphas_cumprod
-    sqrt_alphas_cumprod = alphas_cumprod**0.5
-    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
-        timesteps
-    ].float()
-    while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
-        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
-    alpha = sqrt_alphas_cumprod.expand(sample.shape)
-
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
-        device=timesteps.device
-    )[timesteps].float()
-    while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
-        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
-    sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
+    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+    alpha = sqrt_alphas_cumprod
+    sigma = sqrt_one_minus_alphas_cumprod
+    all_snr = (alpha / sigma) ** 2
+
+    noise_scheduler.all_snr = all_snr.to(device)
+
+
+def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
+    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.minimum(
+        gamma_over_snr, torch.ones_like(gamma_over_snr)
+    ).float()  # min(gamma/SNR, 1) weighting from the Min-SNR paper (arXiv:2303.09556)
+    loss = loss * snr_weight
+    return loss
 
-    if noise_scheduler.config.prediction_type == "epsilon":
-        pred_original_sample = (sample - sigma * model_output) / alpha
-    elif noise_scheduler.config.prediction_type == "sample":
-        pred_original_sample = model_output
-    elif noise_scheduler.config.prediction_type == "v_prediction":
-        pred_original_sample = alpha * sample - sigma * model_output
-    else:
-        raise ValueError(
-            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
-            " `v_prediction`  for the DDPMScheduler."
-        )
-    return pred_original_sample
+
+def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
+    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
+    snr_t = torch.minimum(
+        snr_t, torch.ones_like(snr_t) * 1000
+    )  # SNR is infinite at timestep 0, so cap it at 1000
+    scale = snr_t / (snr_t + 1)
+
+    loss = loss * scale
+    return loss
 
 
 def loss_step(
@@ -347,6 +320,7 @@ def loss_step(
     seed: int,
     input_pertubation: float,
     min_snr_gamma: int,
+    scale_v_pred_loss_like_noise_pred: bool,
     step: int,
     batch: dict[str, Any],
     cache: dict[Any, Any],
@@ -433,14 +407,12 @@ def loss_step(
     loss = loss.mean([1, 2, 3])
 
     if min_snr_gamma != 0:
-        snr = compute_snr(timesteps, noise_scheduler)
-        mse_loss_weights = (
-            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(
-                dim=1
-            )[0]
-            / snr
+        loss = apply_snr_weight(loss, timesteps, noise_scheduler, min_snr_gamma)
+
+    if scale_v_pred_loss_like_noise_pred:
+        loss = scale_v_prediction_loss_like_noise_prediction(
+            loss, timesteps, noise_scheduler
         )
-        loss = loss * mse_loss_weights
 
     if isinstance(schedule_sampler, LossAwareSampler):
         schedule_sampler.update_with_all_losses(timesteps, loss.detach())
@@ -726,6 +698,7 @@ def train(
     input_pertubation: float = 0.1,
     schedule_sampler: Optional[ScheduleSampler] = None,
     min_snr_gamma: int = 5,
+    scale_v_pred_loss_like_noise_pred: bool = True,
     avg_loss: AverageMeter = AverageMeter(),
     avg_acc: AverageMeter = AverageMeter(),
     avg_loss_val: AverageMeter = AverageMeter(),
@@ -785,6 +758,7 @@ def train(
         seed,
         input_pertubation,
         min_snr_gamma,
+        scale_v_pred_loss_like_noise_pred,
     )
 
     train_loop(
-- 
cgit v1.2.3-70-g09d2