diff options
author | Volpeon <git@volpeon.ink> | 2023-03-18 22:15:49 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-18 22:15:49 +0100 |
commit | 873bff68e5f37c85753df4240f3b2b2b88fc99a7 (patch) | |
tree | 7d27183c6ef5ea3f98fcfd247dd1a8e642386172 | |
parent | Better SNR weighting (diff) | |
download | textual-inversion-diff-873bff68e5f37c85753df4240f3b2b2b88fc99a7.tar.gz textual-inversion-diff-873bff68e5f37c85753df4240f3b2b2b88fc99a7.tar.bz2 textual-inversion-diff-873bff68e5f37c85753df4240f3b2b2b88fc99a7.zip |
New loss weighting from arXiv:2204.00227
-rw-r--r-- | training/functional.py | 7 |
1 file changed, 5 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py index b9574ec..15b95ba 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -314,9 +314,12 @@ def loss_step( | |||
314 | elif noise_scheduler.config.prediction_type == "v_prediction": | 314 | elif noise_scheduler.config.prediction_type == "v_prediction": |
315 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 315 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
316 | 316 | ||
317 | p2_gamma = 1 | ||
318 | p2_k = 1 | ||
319 | |||
317 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() | 320 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() |
318 | snr = alpha_t / (1 - alpha_t) | 321 | snr = 1.0 / (1 - alpha_t) - 1 |
319 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) / (snr + 1) | 322 | snr_weights = 1 / (p2_k + snr) ** p2_gamma |
320 | snr_weights = snr_weights[..., None, None, None] | 323 | snr_weights = snr_weights[..., None, None, None] |
321 | else: | 324 | else: |
322 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 325 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |