summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-06 16:25:36 +0200
committerVolpeon <git@volpeon.ink>2023-05-06 16:25:36 +0200
commit7b04d813739c0b5595295dffdc86cc41108db2d3 (patch)
tree8958b612f5d3d665866770ad553e1004aa4b6fb8 /training/functional.py
parentUpdate (diff)
downloadtextual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.gz
textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.tar.bz2
textual-inversion-diff-7b04d813739c0b5595295dffdc86cc41108db2d3.zip
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py20
1 file changed, 12 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py
index e7e1eb3..eae5681 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings
25from models.clip.tokenizer import MultiCLIPTokenizer 25from models.clip.tokenizer import MultiCLIPTokenizer
26from models.convnext.discriminator import ConvNeXtDiscriminator 26from models.convnext.discriminator import ConvNeXtDiscriminator
27from training.util import AverageMeter 27from training.util import AverageMeter
28from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
28from util.slerp import slerp 29from util.slerp import slerp
29 30
30 31
@@ -318,6 +319,7 @@ def get_original(
318def loss_step( 319def loss_step(
319 vae: AutoencoderKL, 320 vae: AutoencoderKL,
320 noise_scheduler: SchedulerMixin, 321 noise_scheduler: SchedulerMixin,
322 schedule_sampler: ScheduleSampler,
321 unet: UNet2DConditionModel, 323 unet: UNet2DConditionModel,
322 text_encoder: CLIPTextModel, 324 text_encoder: CLIPTextModel,
323 guidance_scale: float, 325 guidance_scale: float,
@@ -362,14 +364,7 @@ def loss_step(
362 new_noise = noise + input_pertubation * torch.randn_like(noise) 364 new_noise = noise + input_pertubation * torch.randn_like(noise)
363 365
364 # Sample a random timestep for each image 366 # Sample a random timestep for each image
365 timesteps = torch.randint( 367 timesteps, weights = schedule_sampler.sample(bsz, latents.device)
366 0,
367 noise_scheduler.config.num_train_timesteps,
368 (bsz,),
369 generator=generator,
370 device=latents.device,
371 )
372 timesteps = timesteps.long()
373 368
374 # Add noise to the latents according to the noise magnitude at each timestep 369 # Add noise to the latents according to the noise magnitude at each timestep
375 # (this is the forward diffusion process) 370 # (this is the forward diffusion process)
@@ -443,6 +438,10 @@ def loss_step(
443 ) 438 )
444 loss = loss * mse_loss_weights 439 loss = loss * mse_loss_weights
445 440
441 if isinstance(schedule_sampler, LossAwareSampler):
442 schedule_sampler.update_with_all_losses(timesteps, loss.detach())
443
444 loss = loss * weights
446 loss = loss.mean() 445 loss = loss.mean()
447 446
448 return loss, acc, bsz 447 return loss, acc, bsz
@@ -694,6 +693,7 @@ def train(
694 offset_noise_strength: float = 0.01, 693 offset_noise_strength: float = 0.01,
695 input_pertubation: float = 0.1, 694 input_pertubation: float = 0.1,
696 disc: Optional[ConvNeXtDiscriminator] = None, 695 disc: Optional[ConvNeXtDiscriminator] = None,
696 schedule_sampler: Optional[ScheduleSampler] = None,
697 min_snr_gamma: int = 5, 697 min_snr_gamma: int = 5,
698 avg_loss: AverageMeter = AverageMeter(), 698 avg_loss: AverageMeter = AverageMeter(),
699 avg_acc: AverageMeter = AverageMeter(), 699 avg_acc: AverageMeter = AverageMeter(),
@@ -725,10 +725,14 @@ def train(
725 **kwargs, 725 **kwargs,
726 ) 726 )
727 727
728 if schedule_sampler is None:
729 schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps)
730
728 loss_step_ = partial( 731 loss_step_ = partial(
729 loss_step, 732 loss_step,
730 vae, 733 vae,
731 noise_scheduler, 734 noise_scheduler,
735 schedule_sampler,
732 unet, 736 unet,
733 text_encoder, 737 text_encoder,
734 guidance_scale, 738 guidance_scale,