diff options
author | Volpeon <git@volpeon.ink> | 2023-03-25 17:49:30 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-25 17:49:30 +0100 |
commit | d69cc8f46f238e91e2f597cd301cc53b1d4b8bec (patch) | |
tree | 9c126915b658a46efcb0ecd5ac0f373f11a09ae7 /training | |
parent | Update (diff) | |
download | textual-inversion-diff-d69cc8f46f238e91e2f597cd301cc53b1d4b8bec.tar.gz textual-inversion-diff-d69cc8f46f238e91e2f597cd301cc53b1d4b8bec.tar.bz2 textual-inversion-diff-d69cc8f46f238e91e2f597cd301cc53b1d4b8bec.zip |
Fix training with guidance
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index d285366..109845b 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -344,9 +344,15 @@ def loss_step( | |||
344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 344 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
345 | 345 | ||
346 | if guidance_scale != 0: | 346 | if guidance_scale != 0: |
347 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 347 | uncond_encoder_hidden_states = get_extended_embeddings( |
348 | model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0) | 348 | text_encoder, |
349 | model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) | 349 | batch["negative_input_ids"], |
350 | batch["negative_attention_mask"] | ||
351 | ) | ||
352 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | ||
353 | |||
354 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample | ||
355 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | ||
350 | 356 | ||
351 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 357 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
352 | elif prior_loss_weight != 0: | 358 | elif prior_loss_weight != 0: |