summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py
index d285366..109845b 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -344,9 +344,15 @@ def loss_step(
344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
345 345
346 if guidance_scale != 0: 346 if guidance_scale != 0:
347 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 347 uncond_encoder_hidden_states = get_extended_embeddings(
348 model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0) 348 text_encoder,
349 model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) 349 batch["negative_input_ids"],
350 batch["negative_attention_mask"]
351 )
352 uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
353
354 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
355 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
350 356
351 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 357 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
352 elif prior_loss_weight != 0: 358 elif prior_loss_weight != 0: