diff options
author | Volpeon <git@volpeon.ink> | 2023-03-19 14:37:07 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-19 14:37:07 +0100 |
commit | 6c83a18aa8cf1d0d2a972bc8393584eb61b9deac (patch) | |
tree | e5c40656a6509abfbe7a014f8af0ab523c8c834c | |
parent | New loss weighting from arxiv.org:2204.00227 (diff) | |
download | textual-inversion-diff-6c83a18aa8cf1d0d2a972bc8393584eb61b9deac.tar.gz textual-inversion-diff-6c83a18aa8cf1d0d2a972bc8393584eb61b9deac.tar.bz2 textual-inversion-diff-6c83a18aa8cf1d0d2a972bc8393584eb61b9deac.zip |
Restore min SNR
-rw-r--r-- | training/functional.py | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py index 15b95ba..8dc2b9f 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -261,7 +261,8 @@ def loss_step( | |||
261 | seed: int, | 261 | seed: int, |
262 | step: int, | 262 | step: int, |
263 | batch: dict[str, Any], | 263 | batch: dict[str, Any], |
264 | eval: bool = False | 264 | eval: bool = False, |
265 | min_snr_gamma: int = 5 | ||
265 | ): | 266 | ): |
266 | # Convert images to latent space | 267 | # Convert images to latent space |
267 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 268 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
@@ -307,23 +308,21 @@ def loss_step( | |||
307 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 308 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
308 | 309 | ||
309 | # Get the target for loss depending on the prediction type | 310 | # Get the target for loss depending on the prediction type |
311 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() | ||
312 | snr = alpha_t / (1 - alpha_t) | ||
313 | min_snr = snr.clamp(max=min_snr_gamma) | ||
314 | |||
310 | if noise_scheduler.config.prediction_type == "epsilon": | 315 | if noise_scheduler.config.prediction_type == "epsilon": |
311 | target = noise | 316 | target = noise |
312 | 317 | loss_weight = min_snr / snr | |
313 | snr_weights = 1 | ||
314 | elif noise_scheduler.config.prediction_type == "v_prediction": | 318 | elif noise_scheduler.config.prediction_type == "v_prediction": |
315 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 319 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
316 | 320 | loss_weight = min_snr / (snr + 1) | |
317 | p2_gamma = 1 | ||
318 | p2_k = 1 | ||
319 | |||
320 | alpha_t = noise_scheduler.alphas_cumprod[timesteps].float() | ||
321 | snr = 1.0 / (1 - alpha_t) - 1 | ||
322 | snr_weights = 1 / (p2_k + snr) ** p2_gamma | ||
323 | snr_weights = snr_weights[..., None, None, None] | ||
324 | else: | 321 | else: |
325 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 322 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
326 | 323 | ||
324 | loss_weight = loss_weight[..., None, None, None] | ||
325 | |||
327 | if with_prior_preservation: | 326 | if with_prior_preservation: |
328 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 327 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
329 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 328 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
@@ -340,7 +339,7 @@ def loss_step( | |||
340 | else: | 339 | else: |
341 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 340 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
342 | 341 | ||
343 | loss = (snr_weights * loss).mean([1, 2, 3]).mean() | 342 | loss = (loss_weight * loss).mean([1, 2, 3]).mean() |
344 | acc = (model_pred == target).float().mean() | 343 | acc = (model_pred == target).float().mean() |
345 | 344 | ||
346 | return loss, acc, bsz | 345 | return loss, acc, bsz |
@@ -413,7 +412,7 @@ def train_loop( | |||
413 | try: | 412 | try: |
414 | for epoch in range(num_epochs): | 413 | for epoch in range(num_epochs): |
415 | if accelerator.is_main_process: | 414 | if accelerator.is_main_process: |
416 | if epoch % sample_frequency == 0: | 415 | if epoch % sample_frequency == 0 and epoch != 0: |
417 | local_progress_bar.clear() | 416 | local_progress_bar.clear() |
418 | global_progress_bar.clear() | 417 | global_progress_bar.clear() |
419 | 418 | ||