diff options
| author | Volpeon <git@volpeon.ink> | 2023-03-18 09:42:28 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-18 09:42:28 +0100 |
| commit | 8c9dd1a230daf8a662447465c32dcae46ecbbe5f (patch) | |
| tree | e8ae48df2da9c3e32ff468f90e7301777f3e5206 /training | |
| parent | Fixed snr weight calculation (diff) | |
| download | textual-inversion-diff-8c9dd1a230daf8a662447465c32dcae46ecbbe5f.tar.gz textual-inversion-diff-8c9dd1a230daf8a662447465c32dcae46ecbbe5f.tar.bz2 textual-inversion-diff-8c9dd1a230daf8a662447465c32dcae46ecbbe5f.zip | |
Better SNR weighting
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 6 |
1 file changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index 27527ef..b9574ec 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -314,10 +314,10 @@ def loss_step( | |||
| 314 | elif noise_scheduler.config.prediction_type == "v_prediction": | 314 | elif noise_scheduler.config.prediction_type == "v_prediction": |
| 315 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 315 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
| 316 | 316 | ||
| 317 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()[..., None, None, None].expand(bsz, 1, 1, 1) | 317 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() |
| 318 | snr = alpha_t / (1 - alpha_t) | 318 | snr = alpha_t / (1 - alpha_t) |
| 319 | snr /= snr + 1 | 319 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) / (snr + 1) |
| 320 | snr_weights = snr.fmin(torch.tensor([5], device=latents.device)) | 320 | snr_weights = snr_weights[..., None, None, None] |
| 321 | else: | 321 | else: |
| 322 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 322 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 323 | 323 | ||
