diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/training/functional.py b/training/functional.py index ebb48ab..015fe5e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -259,7 +259,7 @@ def snr_weight(noisy_latents, latents, gamma): | |||
259 | sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) | 259 | sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) |
260 | snr = torch.div(alpha_mean_sq, sigma_mean_sq) | 260 | snr = torch.div(alpha_mean_sq, sigma_mean_sq) |
261 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) | 261 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) |
262 | snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() | 262 | snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() |
263 | return snr_weight | 263 | return snr_weight |
264 | 264 | ||
265 | return torch.tensor( | 265 | return torch.tensor( |
@@ -471,10 +471,7 @@ def train_loop( | |||
471 | "lr": lr_scheduler.get_last_lr()[0], | 471 | "lr": lr_scheduler.get_last_lr()[0], |
472 | } | 472 | } |
473 | if isDadaptation: | 473 | if isDadaptation: |
474 | logs["lr/d*lr"] = ( | 474 | logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] |
475 | optimizer.param_groups[0]["d"] * | ||
476 | optimizer.param_groups[0]["lr"] | ||
477 | ) | ||
478 | logs.update(on_log()) | 475 | logs.update(on_log()) |
479 | 476 | ||
480 | local_progress_bar.set_postfix(**logs) | 477 | local_progress_bar.set_postfix(**logs) |