Diffstat (limited to 'training/functional.py')

 -rw-r--r--  training/functional.py  35
 1 file changed, 25 insertions(+), 10 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 8dc2b9f..43ee356 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -251,6 +251,25 @@ def add_placeholder_tokens(
     return placeholder_token_ids, initializer_token_ids
 
 
+def snr_weight(noisy_latents, latents, gamma):
+    if gamma:
+        sigma = torch.sub(noisy_latents, latents)
+        zeros = torch.zeros_like(sigma)
+        alpha_mean_sq = F.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
+        snr = torch.div(alpha_mean_sq, sigma_mean_sq)
+        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
+        return snr_weight
+
+    return torch.tensor(
+        [1],
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+    )
+
+
 def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: SchedulerMixin,
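The snr_weight helper added above implements Min-SNR loss weighting: it estimates each sample's signal-to-noise ratio empirically as mean(latents^2) / mean((noisy_latents - latents)^2) and returns min(gamma / SNR, 1), which equals min(SNR, gamma) / SNR. A minimal usage sketch on dummy tensors, assuming the repo's training package is importable (the batch shape, noise scale, and gamma value are illustrative assumptions, not values from the training code):

    import torch
    from training.functional import snr_weight

    latents = torch.randn(4, 4, 64, 64)      # [batch, channels, height, width]
    noise = torch.randn_like(latents)
    noisy_latents = latents + 0.5 * noise    # stand-in for the scheduler's add_noise()

    weight = snr_weight(noisy_latents, latents, 5.0)
    print(weight.shape)                      # torch.Size([4]): one weight per sample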
@@ -308,21 +327,13 @@ def loss_step(
     model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
     # Get the target for loss depending on the prediction type
-    alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
-    snr = alpha_t / (1 - alpha_t)
-    min_snr = snr.clamp(max=min_snr_gamma)
-
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
-        loss_weight = min_snr / snr
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        loss_weight = min_snr / (snr + 1)
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    loss_weight = loss_weight[..., None, None, None]
-
     if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
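For reference, the deleted lines above computed the same Min-SNR weight in closed form from the scheduler's noise schedule instead of estimating it from the latents. Restated as a standalone helper (the function name is hypothetical; the logic is copied from the removed epsilon branch, while the removed v_prediction branch divided by snr + 1 instead):

    import torch

    def closed_form_min_snr_weight(noise_scheduler, timesteps, gamma):
        # alphas_cumprod: the scheduler's 1-D tensor of cumulative alpha-bar products
        alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
        snr = alpha_t / (1 - alpha_t)
        return snr.clamp(max=gamma) / snr  # = min(snr, gamma) / snr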
@@ -339,7 +350,11 @@ def loss_step(
     else:
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
-    loss = (loss_weight * loss).mean([1, 2, 3]).mean()
+    loss = loss.mean([1, 2, 3])
+
+    loss_weight = snr_weight(noisy_latents, latents, min_snr_gamma)
+    loss = (loss_weight * loss).mean()
+
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
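Note that when min_snr_gamma is falsy, snr_weight falls through to a one-element tensor of ones, which broadcasts against the per-sample loss vector and leaves the mean unchanged, so the weighting is effectively a no-op in that case.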
@@ -412,7 +427,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and epoch != 0:
+                if epoch % sample_frequency == 0:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
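With the epoch != 0 guard dropped, the sample_frequency branch in train_loop now also fires at epoch 0 instead of being skipped on the first pass through the loop.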
