summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-05 19:31:47 +0200
committerVolpeon <git@volpeon.ink>2023-04-05 19:31:47 +0200
commit3784a976d3bb960d370854fc213a3173ae697acc (patch)
tree15e32dbbeb0fb44ca8075d41a21a962ac0acde2e
parentImproved slerp noise offset: Dedicated black image instead of negative offset (diff)
downloadtextual-inversion-diff-3784a976d3bb960d370854fc213a3173ae697acc.tar.gz
textual-inversion-diff-3784a976d3bb960d370854fc213a3173ae697acc.tar.bz2
textual-inversion-diff-3784a976d3bb960d370854fc213a3173ae697acc.zip
MinSNR code from diffusers
-rw-r--r--training/functional.py49
1 files changed, 30 insertions, 19 deletions
diff --git a/training/functional.py b/training/functional.py
index 68071bc..06848cb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -251,23 +251,29 @@ def add_placeholder_tokens(
251 return placeholder_token_ids, initializer_token_ids 251 return placeholder_token_ids, initializer_token_ids
252 252
253 253
254def snr_weight(noisy_latents, latents, gamma): 254def compute_snr(timesteps, noise_scheduler):
255 if gamma: 255 """
256 sigma = torch.sub(noisy_latents, latents) 256 Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
257 zeros = torch.zeros_like(sigma) 257 """
258 alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) 258 alphas_cumprod = noise_scheduler.alphas_cumprod
259 sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) 259 sqrt_alphas_cumprod = alphas_cumprod**0.5
260 snr = torch.div(alpha_mean_sq, sigma_mean_sq) 260 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
261 gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) 261
262 snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() 262 # Expand the tensors.
263 return snr_weight 263 # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
264 264 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
265 return torch.tensor( 265 while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
266 [1], 266 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
267 dtype=latents.dtype, 267 alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
268 layout=latents.layout, 268
269 device=latents.device, 269 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
270 ) 270 while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
271 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
272 sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
273
274 # Compute SNR.
275 snr = (alpha / sigma) ** 2
276 return snr
271 277
272 278
273def make_solid_image(color: float, shape, vae, dtype, device, generator): 279def make_solid_image(color: float, shape, vae, dtype, device, generator):
@@ -418,9 +424,14 @@ def loss_step(
418 424
419 loss = loss.mean([1, 2, 3]) 425 loss = loss.mean([1, 2, 3])
420 426
421 loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma) 427 if min_snr_gamma != 0:
422 loss = (loss_weight * loss).mean() 428 snr = compute_snr(timesteps, noise_scheduler)
429 mse_loss_weights = (
430 torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
431 )
432 loss *= mse_loss_weights
423 433
434 loss = loss.mean()
424 acc = (model_pred == target).float().mean() 435 acc = (model_pred == target).float().mean()
425 436
426 return loss, acc, bsz 437 return loss, acc, bsz