summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py13
1 files changed, 0 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py
index 3c7848f..a3d1f08 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -29,11 +29,8 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
29from models.clip.embeddings import ManagedCLIPTextEmbeddings 29from models.clip.embeddings import ManagedCLIPTextEmbeddings
30from models.clip.util import get_extended_embeddings 30from models.clip.util import get_extended_embeddings
31from models.clip.tokenizer import MultiCLIPTokenizer 31from models.clip.tokenizer import MultiCLIPTokenizer
32from models.convnext.discriminator import ConvNeXtDiscriminator
33from training.util import AverageMeter 32from training.util import AverageMeter
34from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler 33from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
35from util.slerp import slerp
36from util.noise import perlin_noise
37 34
38 35
39def const(result=None): 36def const(result=None):
@@ -349,7 +346,6 @@ def loss_step(
349 prior_loss_weight: float, 346 prior_loss_weight: float,
350 seed: int, 347 seed: int,
351 input_pertubation: float, 348 input_pertubation: float,
352 disc: Optional[ConvNeXtDiscriminator],
353 min_snr_gamma: int, 349 min_snr_gamma: int,
354 step: int, 350 step: int,
355 batch: dict[str, Any], 351 batch: dict[str, Any],
@@ -449,13 +445,6 @@ def loss_step(
449 445
450 loss = loss.mean([1, 2, 3]) 446 loss = loss.mean([1, 2, 3])
451 447
452 if disc is not None:
453 rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
454 rec_latent = rec_latent / vae.config.scaling_factor
455 rec_latent = rec_latent.to(dtype=vae.dtype)
456 rec = vae.decode(rec_latent, return_dict=False)[0]
457 loss = 1 - disc.get_score(rec)
458
459 if min_snr_gamma != 0: 448 if min_snr_gamma != 0:
460 snr = compute_snr(timesteps, noise_scheduler) 449 snr = compute_snr(timesteps, noise_scheduler)
461 mse_loss_weights = ( 450 mse_loss_weights = (
@@ -741,7 +730,6 @@ def train(
741 guidance_scale: float = 0.0, 730 guidance_scale: float = 0.0,
742 prior_loss_weight: float = 1.0, 731 prior_loss_weight: float = 1.0,
743 input_pertubation: float = 0.1, 732 input_pertubation: float = 0.1,
744 disc: Optional[ConvNeXtDiscriminator] = None,
745 schedule_sampler: Optional[ScheduleSampler] = None, 733 schedule_sampler: Optional[ScheduleSampler] = None,
746 min_snr_gamma: int = 5, 734 min_snr_gamma: int = 5,
747 avg_loss: AverageMeter = AverageMeter(), 735 avg_loss: AverageMeter = AverageMeter(),
@@ -803,7 +791,6 @@ def train(
803 prior_loss_weight, 791 prior_loss_weight,
804 seed, 792 seed,
805 input_pertubation, 793 input_pertubation,
806 disc,
807 min_snr_gamma, 794 min_snr_gamma,
808 ) 795 )
809 796