summary refs log tree commit diff stats
path: root/training
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2023-03-26 14:27:54 +0200
committer Volpeon <git@volpeon.ink> 2023-03-26 14:27:54 +0200
commit 19ae465203c8dcc0b1179584db632015362b5e44 (patch)
tree ad6d45e78826f525c336927e4269197667f1f354 /training
parent Fix training with guidance (diff)
download textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.gz
textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.tar.bz2
textual-inversion-diff-19ae465203c8dcc0b1179584db632015362b5e44.zip
Improved inverted tokens
Diffstat (limited to 'training')
-rw-r--r-- training/functional.py 19
1 files changed, 9 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py
index 109845b..a2aa24e 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -335,14 +335,6 @@ def loss_step(
335 # Predict the noise residual 335 # Predict the noise residual
336 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 336 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
337 337
338 # Get the target for loss depending on the prediction type
339 if noise_scheduler.config.prediction_type == "epsilon":
340 target = noise
341 elif noise_scheduler.config.prediction_type == "v_prediction":
342 target = noise_scheduler.get_velocity(latents, noise, timesteps)
343 else:
344 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
345
346 if guidance_scale != 0: 338 if guidance_scale != 0:
347 uncond_encoder_hidden_states = get_extended_embeddings( 339 uncond_encoder_hidden_states = get_extended_embeddings(
348 text_encoder, 340 text_encoder,
@@ -354,8 +346,15 @@ def loss_step(
354 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample 346 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample
355 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) 347 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
356 348
357 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 349 # Get the target for loss depending on the prediction type
358 elif prior_loss_weight != 0: 350 if noise_scheduler.config.prediction_type == "epsilon":
351 target = noise
352 elif noise_scheduler.config.prediction_type == "v_prediction":
353 target = noise_scheduler.get_velocity(latents, noise, timesteps)
354 else:
355 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
356
357 if guidance_scale == 0 and prior_loss_weight != 0:
359 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 358 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
360 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 359 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
361 target, target_prior = torch.chunk(target, 2, dim=0) 360 target, target_prior = torch.chunk(target, 2, dim=0)