diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py
index 109845b..a2aa24e 100644
--- a/training/functional.py
+++ b/training/functional.py
| @@ -335,14 +335,6 @@ def loss_step( | |||
| 335 | # Predict the noise residual | 335 | # Predict the noise residual |
| 336 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 336 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 337 | 337 | ||
| 338 | # Get the target for loss depending on the prediction type | ||
| 339 | if noise_scheduler.config.prediction_type == "epsilon": | ||
| 340 | target = noise | ||
| 341 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 342 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 343 | else: | ||
| 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 345 | |||
| 346 | if guidance_scale != 0: | 338 | if guidance_scale != 0: |
| 347 | uncond_encoder_hidden_states = get_extended_embeddings( | 339 | uncond_encoder_hidden_states = get_extended_embeddings( |
| 348 | text_encoder, | 340 | text_encoder, |
| @@ -354,8 +346,15 @@ def loss_step( | |||
| 354 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | 346 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample |
| 355 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 347 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) |
| 356 | 348 | ||
| 357 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 349 | # Get the target for loss depending on the prediction type |
| 358 | elif prior_loss_weight != 0: | 350 | if noise_scheduler.config.prediction_type == "epsilon": |
| 351 | target = noise | ||
| 352 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
| 353 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
| 354 | else: | ||
| 355 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
| 356 | |||
| 357 | if guidance_scale == 0 and prior_loss_weight != 0: | ||
| 359 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 358 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 360 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 359 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 361 | target, target_prior = torch.chunk(target, 2, dim=0) | 360 | target, target_prior = torch.chunk(target, 2, dim=0) |
