summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-18 09:42:28 +0100
committerVolpeon <git@volpeon.ink>2023-03-18 09:42:28 +0100
commit8c9dd1a230daf8a662447465c32dcae46ecbbe5f (patch)
treee8ae48df2da9c3e32ff468f90e7301777f3e5206 /training/functional.py
parentFixed snr weight calculation (diff)
downloadtextual-inversion-diff-8c9dd1a230daf8a662447465c32dcae46ecbbe5f.tar.gz
textual-inversion-diff-8c9dd1a230daf8a662447465c32dcae46ecbbe5f.tar.bz2
textual-inversion-diff-8c9dd1a230daf8a662447465c32dcae46ecbbe5f.zip
Better SNR weighting
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 27527ef..b9574ec 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -314,10 +314,10 @@ def loss_step(
314 elif noise_scheduler.config.prediction_type == "v_prediction": 314 elif noise_scheduler.config.prediction_type == "v_prediction":
315 target = noise_scheduler.get_velocity(latents, noise, timesteps) 315 target = noise_scheduler.get_velocity(latents, noise, timesteps)
316 316
317 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1) 317 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
318 snr = alpha_t / (1 - alpha_t) 318 snr = alpha_t / (1 - alpha_t)
319 snr /= snr + 1 319 snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) / (snr + 1)
320 snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) 320 snr_weights = snr_weights[..., None, None, None]
321 else: 321 else:
322 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 322 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
323 323