diff options
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 20 |
1 file changed, 12 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py index e7e1eb3..eae5681 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -25,6 +25,7 @@ from models.clip.util import get_extended_embeddings | |||
| 25 | from models.clip.tokenizer import MultiCLIPTokenizer | 25 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 26 | from models.convnext.discriminator import ConvNeXtDiscriminator | 26 | from models.convnext.discriminator import ConvNeXtDiscriminator |
| 27 | from training.util import AverageMeter | 27 | from training.util import AverageMeter |
| 28 | from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler | ||
| 28 | from util.slerp import slerp | 29 | from util.slerp import slerp |
| 29 | 30 | ||
| 30 | 31 | ||
| @@ -318,6 +319,7 @@ def get_original( | |||
| 318 | def loss_step( | 319 | def loss_step( |
| 319 | vae: AutoencoderKL, | 320 | vae: AutoencoderKL, |
| 320 | noise_scheduler: SchedulerMixin, | 321 | noise_scheduler: SchedulerMixin, |
| 322 | schedule_sampler: ScheduleSampler, | ||
| 321 | unet: UNet2DConditionModel, | 323 | unet: UNet2DConditionModel, |
| 322 | text_encoder: CLIPTextModel, | 324 | text_encoder: CLIPTextModel, |
| 323 | guidance_scale: float, | 325 | guidance_scale: float, |
| @@ -362,14 +364,7 @@ def loss_step( | |||
| 362 | new_noise = noise + input_pertubation * torch.randn_like(noise) | 364 | new_noise = noise + input_pertubation * torch.randn_like(noise) |
| 363 | 365 | ||
| 364 | # Sample a random timestep for each image | 366 | # Sample a random timestep for each image |
| 365 | timesteps = torch.randint( | 367 | timesteps, weights = schedule_sampler.sample(bsz, latents.device) |
| 366 | 0, | ||
| 367 | noise_scheduler.config.num_train_timesteps, | ||
| 368 | (bsz,), | ||
| 369 | generator=generator, | ||
| 370 | device=latents.device, | ||
| 371 | ) | ||
| 372 | timesteps = timesteps.long() | ||
| 373 | 368 | ||
| 374 | # Add noise to the latents according to the noise magnitude at each timestep | 369 | # Add noise to the latents according to the noise magnitude at each timestep |
| 375 | # (this is the forward diffusion process) | 370 | # (this is the forward diffusion process) |
| @@ -443,6 +438,10 @@ def loss_step( | |||
| 443 | ) | 438 | ) |
| 444 | loss = loss * mse_loss_weights | 439 | loss = loss * mse_loss_weights |
| 445 | 440 | ||
| 441 | if isinstance(schedule_sampler, LossAwareSampler): | ||
| 442 | schedule_sampler.update_with_all_losses(timesteps, loss.detach()) | ||
| 443 | |||
| 444 | loss = loss * weights | ||
| 446 | loss = loss.mean() | 445 | loss = loss.mean() |
| 447 | 446 | ||
| 448 | return loss, acc, bsz | 447 | return loss, acc, bsz |
| @@ -694,6 +693,7 @@ def train( | |||
| 694 | offset_noise_strength: float = 0.01, | 693 | offset_noise_strength: float = 0.01, |
| 695 | input_pertubation: float = 0.1, | 694 | input_pertubation: float = 0.1, |
| 696 | disc: Optional[ConvNeXtDiscriminator] = None, | 695 | disc: Optional[ConvNeXtDiscriminator] = None, |
| 696 | schedule_sampler: Optional[ScheduleSampler] = None, | ||
| 697 | min_snr_gamma: int = 5, | 697 | min_snr_gamma: int = 5, |
| 698 | avg_loss: AverageMeter = AverageMeter(), | 698 | avg_loss: AverageMeter = AverageMeter(), |
| 699 | avg_acc: AverageMeter = AverageMeter(), | 699 | avg_acc: AverageMeter = AverageMeter(), |
| @@ -725,10 +725,14 @@ def train( | |||
| 725 | **kwargs, | 725 | **kwargs, |
| 726 | ) | 726 | ) |
| 727 | 727 | ||
| 728 | if schedule_sampler is None: | ||
| 729 | schedule_sampler = UniformSampler(noise_scheduler.config.num_train_timesteps) | ||
| 730 | |||
| 728 | loss_step_ = partial( | 731 | loss_step_ = partial( |
| 729 | loss_step, | 732 | loss_step, |
| 730 | vae, | 733 | vae, |
| 731 | noise_scheduler, | 734 | noise_scheduler, |
| 735 | schedule_sampler, | ||
| 732 | unet, | 736 | unet, |
| 733 | text_encoder, | 737 | text_encoder, |
| 734 | guidance_scale, | 738 | guidance_scale, |
