diff options
author | Volpeon <git@volpeon.ink> | 2023-04-05 19:31:47 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-05 19:31:47 +0200 |
commit | 3784a976d3bb960d370854fc213a3173ae697acc (patch) | |
tree | 15e32dbbeb0fb44ca8075d41a21a962ac0acde2e /training | |
parent | Improved slerp noise offset: Dedicated black image instead of negative offset (diff) | |
download | textual-inversion-diff-3784a976d3bb960d370854fc213a3173ae697acc.tar.gz textual-inversion-diff-3784a976d3bb960d370854fc213a3173ae697acc.tar.bz2 textual-inversion-diff-3784a976d3bb960d370854fc213a3173ae697acc.zip |
MinSNR code from diffusers
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 49 |
1 file changed, 30 insertions, 19 deletions
diff --git a/training/functional.py b/training/functional.py index 68071bc..06848cb 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -251,23 +251,29 @@ def add_placeholder_tokens( | |||
251 | return placeholder_token_ids, initializer_token_ids | 251 | return placeholder_token_ids, initializer_token_ids |
252 | 252 | ||
253 | 253 | ||
254 | def snr_weight(noisy_latents, latents, gamma): | 254 | def compute_snr(timesteps, noise_scheduler): |
255 | if gamma: | 255 | """ |
256 | sigma = torch.sub(noisy_latents, latents) | 256 | Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 |
257 | zeros = torch.zeros_like(sigma) | 257 | """ |
258 | alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) | 258 | alphas_cumprod = noise_scheduler.alphas_cumprod |
259 | sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) | 259 | sqrt_alphas_cumprod = alphas_cumprod**0.5 |
260 | snr = torch.div(alpha_mean_sq, sigma_mean_sq) | 260 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 |
261 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) | 261 | |
262 | snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() | 262 | # Expand the tensors. |
263 | return snr_weight | 263 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 |
264 | 264 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | |
265 | return torch.tensor( | 265 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): |
266 | [1], | 266 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] |
267 | dtype=latents.dtype, | 267 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) |
268 | layout=latents.layout, | 268 | |
269 | device=latents.device, | 269 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() |
270 | ) | 270 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): |
271 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | ||
272 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) | ||
273 | |||
274 | # Compute SNR. | ||
275 | snr = (alpha / sigma) ** 2 | ||
276 | return snr | ||
271 | 277 | ||
272 | 278 | ||
273 | def make_solid_image(color: float, shape, vae, dtype, device, generator): | 279 | def make_solid_image(color: float, shape, vae, dtype, device, generator): |
@@ -418,9 +424,14 @@ def loss_step( | |||
418 | 424 | ||
419 | loss = loss.mean([1, 2, 3]) | 425 | loss = loss.mean([1, 2, 3]) |
420 | 426 | ||
421 | loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma) | 427 | if min_snr_gamma != 0: |
422 | loss = (loss_weight * loss).mean() | 428 | snr = compute_snr(timesteps, noise_scheduler) |
429 | mse_loss_weights = ( | ||
430 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | ||
431 | ) | ||
432 | loss *= mse_loss_weights | ||
423 | 433 | ||
434 | loss = loss.mean() | ||
424 | acc = (model_pred == target).float().mean() | 435 | acc = (model_pred == target).float().mean() |
425 | 436 | ||
426 | return loss, acc, bsz | 437 | return loss, acc, bsz |