From 3784a976d3bb960d370854fc213a3173ae697acc Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 5 Apr 2023 19:31:47 +0200 Subject: MinSNR code from diffusers --- training/functional.py | 49 ++++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 19 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 68071bc..06848cb 100644 --- a/training/functional.py +++ b/training/functional.py @@ -251,23 +251,29 @@ def add_placeholder_tokens( return placeholder_token_ids, initializer_token_ids -def snr_weight(noisy_latents, latents, gamma): - if gamma: - sigma = torch.sub(noisy_latents, latents) - zeros = torch.zeros_like(sigma) - alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) - sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) - snr = torch.div(alpha_mean_sq, sigma_mean_sq) - gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() - return snr_weight - - return torch.tensor( - [1], - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - ) +def compute_snr(timesteps, noise_scheduler): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr def make_solid_image(color: float, shape, vae, dtype, device, generator): @@ -418,9 +424,14 @@ def loss_step( loss = loss.mean([1, 2, 3]) - loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma) - loss = (loss_weight * loss).mean() + if min_snr_gamma != 0: + snr = compute_snr(timesteps, noise_scheduler) + mse_loss_weights = ( + torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + loss *= mse_loss_weights + loss = loss.mean() acc = (model_pred == target).float().mean() return loss, acc, bsz -- cgit v1.2.3-54-g00ecf