diff options
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py
index e7e1eb3..eae5681 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings | |||
25 | from models.clip.tokenizer import MultiCLIPTokenizer | 25 | from models.clip.tokenizer import MultiCLIPTokenizer |
26 | from models.convnext.discriminator import ConvNeXtDiscriminator | 26 | from models.convnext.discriminator import ConvNeXtDiscriminator |
27 | from training.util import AverageMeter | 27 | from training.util import AverageMeter |
28 | from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler | ||
28 | from util.slerp import slerp | 29 | from util.slerp import slerp |
29 | 30 | ||
30 | 31 | ||
@@ -318,6 +319,7 @@ def get_original( | |||
318 | def loss_step( | 319 | def loss_step( |
319 | vae: AutoencoderKL, | 320 | vae: AutoencoderKL, |
320 | noise_scheduler: SchedulerMixin, | 321 | noise_scheduler: SchedulerMixin, |
322 | schedule_sampler: ScheduleSampler, | ||
321 | unet: UNet2DConditionModel, | 323 | unet: UNet2DConditionModel, |
322 | text_encoder: CLIPTextModel, | 324 | text_encoder: CLIPTextModel, |
323 | guidance_scale: float, | 325 | guidance_scale: float, |
@@ -362,14 +364,7 @@ def loss_step( | |||
362 | new_noise = noise + input_pertubation * torch.randn_like(noise) | 364 | new_noise = noise + input_pertubation * torch.randn_like(noise) |
363 | 365 | ||
364 | # Sample a random timestep for each image | 366 | # Sample a random timestep for each image |
365 | timesteps = torch.randint( | 367 | timesteps, weights = schedule_sampler.sample(bsz, latents.device) |
366 | 0, | ||
367 | noise_scheduler.config.num_train_timesteps, | ||
368 | (bsz,), | ||
369 | generator=generator, | ||
370 | device=latents.device, | ||
371 | ) | ||
372 | timesteps = timesteps.long() | ||
373 | 368 | ||
374 | # Add noise to the latents according to the noise magnitude at each timestep | 369 | # Add noise to the latents according to the noise magnitude at each timestep |
375 | # (this is the forward diffusion process) | 370 | # (this is the forward diffusion process) |
@@ -443,6 +438,10 @@ def loss_step( | |||
443 | ) | 438 | ) |
444 | loss = loss * mse_loss_weights | 439 | loss = loss * mse_loss_weights |
445 | 440 | ||
441 | if isinstance(schedule_sampler, LossAwareSampler): | ||
442 | schedule_sampler.update_with_all_losses(timesteps, loss.detach()) | ||
443 | |||
444 | loss = loss * weights | ||
446 | loss = loss.mean() | 445 | loss = loss.mean() |
447 | 446 | ||
448 | return loss, acc, bsz | 447 | return loss, acc, bsz |
@@ -694,6 +693,7 @@ def train( | |||
694 | offset_noise_strength: float = 0.01, | 693 | offset_noise_strength: float = 0.01, |
695 | input_pertubation: float = 0.1, | 694 | input_pertubation: float = 0.1, |
696 | disc: Optional[ConvNeXtDiscriminator] = None, | 695 | disc: Optional[ConvNeXtDiscriminator] = None, |
696 | schedule_sampler: Optional[ScheduleSampler] = None, | ||
697 | min_snr_gamma: int = 5, | 697 | min_snr_gamma: int = 5, |
698 | avg_loss: AverageMeter = AverageMeter(), | 698 | avg_loss: AverageMeter = AverageMeter(), |
699 | avg_acc: AverageMeter = AverageMeter(), | 699 | avg_acc: AverageMeter = AverageMeter(), |
@@ -725,10 +725,14 @@ def train( | |||
725 | **kwargs, | 725 | **kwargs, |
726 | ) | 726 | ) |
727 | 727 | ||
728 | if schedule_sampler is None: | ||
729 | schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps) | ||
730 | |||
728 | loss_step_ = partial( | 731 | loss_step_ = partial( |
729 | loss_step, | 732 | loss_step, |
730 | vae, | 733 | vae, |
731 | noise_scheduler, | 734 | noise_scheduler, |
735 | schedule_sampler, | ||
732 | unet, | 736 | unet, |
733 | text_encoder, | 737 | text_encoder, |
734 | guidance_scale, | 738 | guidance_scale, |