diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py index 109845b..a2aa24e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -335,14 +335,6 @@ def loss_step( | |||
335 | # Predict the noise residual | 335 | # Predict the noise residual |
336 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 336 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
337 | 337 | ||
338 | # Get the target for loss depending on the prediction type | ||
339 | if noise_scheduler.config.prediction_type == "epsilon": | ||
340 | target = noise | ||
341 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
342 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
343 | else: | ||
344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
345 | |||
346 | if guidance_scale != 0: | 338 | if guidance_scale != 0: |
347 | uncond_encoder_hidden_states = get_extended_embeddings( | 339 | uncond_encoder_hidden_states = get_extended_embeddings( |
348 | text_encoder, | 340 | text_encoder, |
@@ -354,8 +346,15 @@ def loss_step( | |||
354 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | 346 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample |
355 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 347 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) |
356 | 348 | ||
357 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 349 | # Get the target for loss depending on the prediction type |
358 | elif prior_loss_weight != 0: | 350 | if noise_scheduler.config.prediction_type == "epsilon": |
351 | target = noise | ||
352 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
353 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
354 | else: | ||
355 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
356 | |||
357 | if guidance_scale == 0 and prior_loss_weight != 0: | ||
359 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 358 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
360 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 359 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
361 | target, target_prior = torch.chunk(target, 2, dim=0) | 360 | target, target_prior = torch.chunk(target, 2, dim=0) |