diff options
author | Volpeon <git@volpeon.ink> | 2023-03-17 15:57:46 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-17 15:57:46 +0100 |
commit | 2f3f3644f723f5c1500939c5dfe4cf4da81e4831 (patch) | |
tree | a3d39ff848586df9fe8780949342c5fbc79602f8 | |
parent | Fix loss=nan (diff) | |
download | textual-inversion-diff-2f3f3644f723f5c1500939c5dfe4cf4da81e4831.tar.gz textual-inversion-diff-2f3f3644f723f5c1500939c5dfe4cf4da81e4831.tar.bz2 textual-inversion-diff-2f3f3644f723f5c1500939c5dfe4cf4da81e4831.zip |
Fixed snr weight calculation
-rw-r--r-- | training/functional.py | 5 |
1 file changed, 4 insertions, 1 deletion
diff --git a/training/functional.py b/training/functional.py index 1baf9c6..27527ef 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -309,10 +309,13 @@ def loss_step( | |||
309 | # Get the target for loss depending on the prediction type | 309 | # Get the target for loss depending on the prediction type |
310 | if noise_scheduler.config.prediction_type == "epsilon": | 310 | if noise_scheduler.config.prediction_type == "epsilon": |
311 | target = noise | 311 | target = noise |
312 | |||
312 | snr_weights = 1 | 313 | snr_weights = 1 |
313 | elif noise_scheduler.config.prediction_type == "v_prediction": | 314 | elif noise_scheduler.config.prediction_type == "v_prediction": |
314 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 315 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
315 | snr = target / (1 - target) | 316 | |
317 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1) | ||
318 | snr = alpha_t / (1 - alpha_t) | ||
316 | snr /= snr + 1 | 319 | snr /= snr + 1 |
317 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) | 320 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) |
318 | else: | 321 | else: |