diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 13 |
1 files changed, 0 insertions, 13 deletions
diff --git a/training/functional.py b/training/functional.py index 3c7848f..a3d1f08 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -29,11 +29,8 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
29 | from models.clip.embeddings import ManagedCLIPTextEmbeddings | 29 | from models.clip.embeddings import ManagedCLIPTextEmbeddings |
30 | from models.clip.util import get_extended_embeddings | 30 | from models.clip.util import get_extended_embeddings |
31 | from models.clip.tokenizer import MultiCLIPTokenizer | 31 | from models.clip.tokenizer import MultiCLIPTokenizer |
32 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
33 | from training.util import AverageMeter | 32 | from training.util import AverageMeter |
34 | from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler | 33 | from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler |
35 | from util.slerp import slerp | ||
36 | from util.noise import perlin_noise | ||
37 | 34 | ||
38 | 35 | ||
39 | def const(result=None): | 36 | def const(result=None): |
@@ -349,7 +346,6 @@ def loss_step( | |||
349 | prior_loss_weight: float, | 346 | prior_loss_weight: float, |
350 | seed: int, | 347 | seed: int, |
351 | input_pertubation: float, | 348 | input_pertubation: float, |
352 | disc: Optional[ConvNeXtDiscriminator], | ||
353 | min_snr_gamma: int, | 349 | min_snr_gamma: int, |
354 | step: int, | 350 | step: int, |
355 | batch: dict[str, Any], | 351 | batch: dict[str, Any], |
@@ -449,13 +445,6 @@ def loss_step( | |||
449 | 445 | ||
450 | loss = loss.mean([1, 2, 3]) | 446 | loss = loss.mean([1, 2, 3]) |
451 | 447 | ||
452 | if disc is not None: | ||
453 | rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) | ||
454 | rec_latent = rec_latent / vae.config.scaling_factor | ||
455 | rec_latent = rec_latent.to(dtype=vae.dtype) | ||
456 | rec = vae.decode(rec_latent, return_dict=False)[0] | ||
457 | loss = 1 - disc.get_score(rec) | ||
458 | |||
459 | if min_snr_gamma != 0: | 448 | if min_snr_gamma != 0: |
460 | snr = compute_snr(timesteps, noise_scheduler) | 449 | snr = compute_snr(timesteps, noise_scheduler) |
461 | mse_loss_weights = ( | 450 | mse_loss_weights = ( |
@@ -741,7 +730,6 @@ def train( | |||
741 | guidance_scale: float = 0.0, | 730 | guidance_scale: float = 0.0, |
742 | prior_loss_weight: float = 1.0, | 731 | prior_loss_weight: float = 1.0, |
743 | input_pertubation: float = 0.1, | 732 | input_pertubation: float = 0.1, |
744 | disc: Optional[ConvNeXtDiscriminator] = None, | ||
745 | schedule_sampler: Optional[ScheduleSampler] = None, | 733 | schedule_sampler: Optional[ScheduleSampler] = None, |
746 | min_snr_gamma: int = 5, | 734 | min_snr_gamma: int = 5, |
747 | avg_loss: AverageMeter = AverageMeter(), | 735 | avg_loss: AverageMeter = AverageMeter(), |
@@ -803,7 +791,6 @@ def train( | |||
803 | prior_loss_weight, | 791 | prior_loss_weight, |
804 | seed, | 792 | seed, |
805 | input_pertubation, | 793 | input_pertubation, |
806 | disc, | ||
807 | min_snr_gamma, | 794 | min_snr_gamma, |
808 | ) | 795 | ) |
809 | 796 | ||