summary | refs | log | tree | commit | diff | stats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r-- training/functional.py 35
1 files changed, 17 insertions, 18 deletions
diff --git a/training/functional.py b/training/functional.py
index be39776..ed8ae3a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -168,8 +168,7 @@ def save_samples(
168 image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] 168 image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
169 image_grid.save(file_path, quality=85) 169 image_grid.save(file_path, quality=85)
170 170
171 del generator 171 del generator, pipeline
172 del pipeline
173 172
174 if torch.cuda.is_available(): 173 if torch.cuda.is_available():
175 torch.cuda.empty_cache() 174 torch.cuda.empty_cache()
@@ -398,31 +397,32 @@ def loss_step(
398 else: 397 else:
399 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 398 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
400 399
401 if disc is None: 400 acc = (model_pred == target).float().mean()
402 if guidance_scale == 0 and prior_loss_weight != 0:
403 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
404 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
405 target, target_prior = torch.chunk(target, 2, dim=0)
406 401
407 # Compute instance loss 402 if guidance_scale == 0 and prior_loss_weight != 0:
408 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 403 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
404 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
405 target, target_prior = torch.chunk(target, 2, dim=0)
409 406
410 # Compute prior loss 407 # Compute instance loss
411 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") 408 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
412 409
413 # Add the prior loss to the instance loss. 410 # Compute prior loss
414 loss = loss + prior_loss_weight * prior_loss 411 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
415 else:
416 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
417 412
418 loss = loss.mean([1, 2, 3]) 413 # Add the prior loss to the instance loss.
414 loss = loss + prior_loss_weight * prior_loss
419 else: 415 else:
416 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
417
418 loss = loss.mean([1, 2, 3])
419
420 if disc is not None:
420 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) 421 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
421 rec_latent /= vae.config.scaling_factor 422 rec_latent /= vae.config.scaling_factor
422 rec_latent = rec_latent.to(dtype=vae.dtype) 423 rec_latent = rec_latent.to(dtype=vae.dtype)
423 rec = vae.decode(rec_latent).sample 424 rec = vae.decode(rec_latent).sample
424 loss = 1 - disc.get_score(rec) 425 loss = 1 - disc.get_score(rec)
425 del rec_latent, rec
426 426
427 if min_snr_gamma != 0: 427 if min_snr_gamma != 0:
428 snr = compute_snr(timesteps, noise_scheduler) 428 snr = compute_snr(timesteps, noise_scheduler)
@@ -432,7 +432,6 @@ def loss_step(
432 loss *= mse_loss_weights 432 loss *= mse_loss_weights
433 433
434 loss = loss.mean() 434 loss = loss.mean()
435 acc = (model_pred == target).float().mean()
436 435
437 return loss, acc, bsz 436 return loss, acc, bsz
438 437