summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-19 14:37:07 +0100
committerVolpeon <git@volpeon.ink>2023-03-19 14:37:07 +0100
commit6c83a18aa8cf1d0d2a972bc8393584eb61b9deac (patch)
treee5c40656a6509abfbe7a014f8af0ab523c8c834c /training/functional.py
parentNew loss weighting from arxiv.org:2204.00227 (diff)
downloadtextual-inversion-diff-6c83a18aa8cf1d0d2a972bc8393584eb61b9deac.tar.gz
textual-inversion-diff-6c83a18aa8cf1d0d2a972bc8393584eb61b9deac.tar.bz2
textual-inversion-diff-6c83a18aa8cf1d0d2a972bc8393584eb61b9deac.zip
Restore min SNR
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py25
1 files changed, 12 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py
index 15b95ba..8dc2b9f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -261,7 +261,8 @@ def loss_step(
261 seed: int, 261 seed: int,
262 step: int, 262 step: int,
263 batch: dict[str, Any], 263 batch: dict[str, Any],
264 eval: bool = False 264 eval: bool = False,
265 min_snr_gamma: int = 5
265): 266):
266 # Convert images to latent space 267 # Convert images to latent space
267 latents = vae.encode(batch["pixel_values"]).latent_dist.sample() 268 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
@@ -307,23 +308,21 @@ def loss_step(
307 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 308 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
308 309
309 # Get the target for loss depending on the prediction type 310 # Get the target for loss depending on the prediction type
311 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
312 snr = alpha_t / (1 - alpha_t)
313 min_snr = snr.clamp(max=min_snr_gamma)
314
310 if noise_scheduler.config.prediction_type == "epsilon": 315 if noise_scheduler.config.prediction_type == "epsilon":
311 target = noise 316 target = noise
312 317 loss_weight = min_snr / snr
313 snr_weights = 1
314 elif noise_scheduler.config.prediction_type == "v_prediction": 318 elif noise_scheduler.config.prediction_type == "v_prediction":
315 target = noise_scheduler.get_velocity(latents, noise, timesteps) 319 target = noise_scheduler.get_velocity(latents, noise, timesteps)
316 320 loss_weight = min_snr / (snr + 1)
317 p2_gamma = 1
318 p2_k = 1
319
320 alpha_t = noise_scheduler.alphas_cumprod[timesteps].float()
321 snr = 1.0 / (1 - alpha_t) - 1
322 snr_weights = 1 / (p2_k + snr) ** p2_gamma
323 snr_weights = snr_weights[..., None, None, None]
324 else: 321 else:
325 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 322 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
326 323
324 loss_weight = loss_weight[..., None, None, None]
325
327 if with_prior_preservation: 326 if with_prior_preservation:
328 # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 327 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
329 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 328 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
@@ -340,7 +339,7 @@ def loss_step(
340 else: 339 else:
341 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 340 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
342 341
343 loss = (snr_weights * loss).mean([1, 2, 3]).mean() 342 loss = (loss_weight * loss).mean([1, 2, 3]).mean()
344 acc = (model_pred == target).float().mean() 343 acc = (model_pred == target).float().mean()
345 344
346 return loss, acc, bsz 345 return loss, acc, bsz
@@ -413,7 +412,7 @@ def train_loop(
413 try: 412 try:
414 for epoch in range(num_epochs): 413 for epoch in range(num_epochs):
415 if accelerator.is_main_process: 414 if accelerator.is_main_process:
416 if epoch % sample_frequency == 0: 415 if epoch % sample_frequency == 0 and epoch != 0:
417 local_progress_bar.clear() 416 local_progress_bar.clear()
418 global_progress_bar.clear() 417 global_progress_bar.clear()
419 418