summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-17 15:57:46 +0100
committerVolpeon <git@volpeon.ink>2023-03-17 15:57:46 +0100
commit2f3f3644f723f5c1500939c5dfe4cf4da81e4831 (patch)
treea3d39ff848586df9fe8780949342c5fbc79602f8 /training/functional.py
parentFix loss=nan (diff)
downloadtextual-inversion-diff-2f3f3644f723f5c1500939c5dfe4cf4da81e4831.tar.gz
textual-inversion-diff-2f3f3644f723f5c1500939c5dfe4cf4da81e4831.tar.bz2
textual-inversion-diff-2f3f3644f723f5c1500939c5dfe4cf4da81e4831.zip
Fixed snr weight calculation
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/training/functional.py b/training/functional.py
index 1baf9c6..27527ef 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -309,10 +309,13 @@ def loss_step(
309 # Get the target for loss depending on the prediction type 309 # Get the target for loss depending on the prediction type
310 if noise_scheduler.config.prediction_type == "epsilon": 310 if noise_scheduler.config.prediction_type == "epsilon":
311 target = noise 311 target = noise
312
312 snr_weights = 1 313 snr_weights = 1
313 elif noise_scheduler.config.prediction_type == "v_prediction": 314 elif noise_scheduler.config.prediction_type == "v_prediction":
314 target = noise_scheduler.get_velocity(latents, noise, timesteps) 315 target = noise_scheduler.get_velocity(latents, noise, timesteps)
315 snr = target / (1 - target) 316
317 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1)
318 snr = alpha_t / (1 - alpha_t)
316 snr /= snr + 1 319 snr /= snr + 1
317 snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) 320 snr_weights = snr.fmin(torch.tensor([5], device=latents.device))
318 else: 321 else: