From 7b04d813739c0b5595295dffdc86cc41108db2d3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 6 May 2023 16:25:36 +0200 Subject: Update --- training/functional.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) (limited to 'training/functional.py') diff --git a/training/functional.py b/training/functional.py index e7e1eb3..eae5681 100644 --- a/training/functional.py +++ b/training/functional.py @@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings from models.clip.tokenizer import MultiCLIPTokenizer from models.convnext.discriminator import ConvNeXtDiscriminator from training.util import AverageMeter +from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler from util.slerp import slerp @@ -318,6 +319,7 @@ def get_original( def loss_step( vae: AutoencoderKL, noise_scheduler: SchedulerMixin, + schedule_sampler: ScheduleSampler, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, guidance_scale: float, @@ -362,14 +364,7 @@ def loss_step( new_noise = noise + input_pertubation * torch.randn_like(noise) # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (bsz,), - generator=generator, - device=latents.device, - ) - timesteps = timesteps.long() + timesteps, weights = schedule_sampler.sample(bsz, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -443,6 +438,10 @@ def loss_step( ) loss = loss * mse_loss_weights + if isinstance(schedule_sampler, LossAwareSampler): + schedule_sampler.update_with_all_losses(timesteps, loss.detach()) + + loss = loss * weights loss = loss.mean() return loss, acc, bsz @@ -694,6 +693,7 @@ def train( offset_noise_strength: float = 0.01, input_pertubation: float = 0.1, disc: Optional[ConvNeXtDiscriminator] = None, + schedule_sampler: Optional[ScheduleSampler] = None, min_snr_gamma: int = 5, avg_loss: AverageMeter = AverageMeter(), avg_acc: AverageMeter = AverageMeter(), @@ -725,10 +725,14 @@ def train( **kwargs, ) + if schedule_sampler is None: + schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps) + loss_step_ = partial( loss_step, vae, noise_scheduler, + schedule_sampler, unet, text_encoder, guidance_scale, -- cgit v1.2.3-54-g00ecf