From 19ae465203c8dcc0b1179584db632015362b5e44 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sun, 26 Mar 2023 14:27:54 +0200 Subject: Improved inverted tokens --- training/functional.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 109845b..a2aa24e 100644 --- a/training/functional.py +++ b/training/functional.py @@ -335,14 +335,6 @@ def loss_step( # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if guidance_scale != 0: uncond_encoder_hidden_states = get_extended_embeddings( text_encoder, @@ -354,8 +346,15 @@ def loss_step( model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - elif prior_loss_weight != 0: + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if guidance_scale == 0 and prior_loss_weight != 0: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) -- cgit v1.2.3-70-g09d2