summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-18 22:15:49 +0100
committerVolpeon <git@volpeon.ink>2023-03-18 22:15:49 +0100
commit873bff68e5f37c85753df4240f3b2b2b88fc99a7 (patch)
tree7d27183c6ef5ea3f98fcfd247dd1a8e642386172
parentBetter SNR weighting (diff)
downloadtextual-inversion-diff-873bff68e5f37c85753df4240f3b2b2b88fc99a7.tar.gz
textual-inversion-diff-873bff68e5f37c85753df4240f3b2b2b88fc99a7.tar.bz2
textual-inversion-diff-873bff68e5f37c85753df4240f3b2b2b88fc99a7.zip
New loss weighting from arxiv.org:2204.00227
-rw-r--r--training/functional.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py
index b9574ec..15b95ba 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -314,9 +314,12 @@ def loss_step(
314 elif noise_scheduler.config.prediction_type == "v_prediction": 314 elif noise_scheduler.config.prediction_type == "v_prediction":
315 target = noise_scheduler.get_velocity(latents, noise, timesteps) 315 target = noise_scheduler.get_velocity(latents, noise, timesteps)
316 316
317 p2_gamma = 1
318 p2_k = 1
319
317 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() 320 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
318 snr = alpha_t / (1 - alpha_t) 321 snr = 1.0 / (1 - alpha_t) - 1
319 snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) / (snr + 1) 322 snr_weights = 1 / (p2_k + snr) ** p2_gamma
320 snr_weights = snr_weights[..., None, None, None] 323 snr_weights = snr_weights[..., None, None, None]
321 else: 324 else:
322 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 325 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")