From 0767c7bc82645186159965c2a6be4278e33c6721 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 23 Mar 2023 11:07:57 +0100 Subject: Update --- training/functional.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index ebb48ab..015fe5e 100644 --- a/training/functional.py +++ b/training/functional.py @@ -259,7 +259,7 @@ def snr_weight(noisy_latents, latents, gamma): sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) snr = torch.div(alpha_mean_sq, sigma_mean_sq) gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() + snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() return snr_weight return torch.tensor( @@ -471,10 +471,7 @@ def train_loop( "lr": lr_scheduler.get_last_lr()[0], } if isDadaptation: - logs["lr/d*lr"] = ( - optimizer.param_groups[0]["d"] * - optimizer.param_groups[0]["lr"] - ) + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] logs.update(on_log()) local_progress_bar.set_postfix(**logs) -- cgit v1.2.3-54-g00ecf