From 6b8a93f46f053668c8023520225a18445d48d8f1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 25 Mar 2023 16:34:48 +0100 Subject: Update --- training/functional.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 87bb339..d285366 100644 --- a/training/functional.py +++ b/training/functional.py @@ -274,7 +274,7 @@ def loss_step( noise_scheduler: SchedulerMixin, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, - with_prior_preservation: bool, + guidance_scale: float, prior_loss_weight: float, seed: int, offset_noise_strength: float, @@ -283,13 +283,13 @@ def loss_step( eval: bool = False, min_snr_gamma: int = 5, ): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * vae.config.scaling_factor - - bsz = latents.shape[0] + images = batch["pixel_values"] + generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None + bsz = images.shape[0] - generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None + # Convert images to latent space + latents = vae.encode(images).latent_dist.sample(generator=generator) + latents *= vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn( @@ -301,13 +301,13 @@ def loss_step( ) if offset_noise_strength != 0: - noise += offset_noise_strength * perlin_noise( - latents.shape, - res=1, + offset_noise = torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), dtype=latents.dtype, device=latents.device, generator=generator - ) + ).expand(noise.shape) + noise += offset_noise_strength * offset_noise # Sample a random timestep for each image timesteps = torch.randint( @@ -343,7 +343,13 @@ def loss_step( else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") - if with_prior_preservation: + if guidance_scale != 0: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred_uncond, model_pred_text = torch.chunk(model_pred, 2, dim=0) + model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + elif prior_loss_weight != 0: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) @@ -607,9 +613,9 @@ def train( checkpoint_frequency: int = 50, milestone_checkpoints: bool = True, global_step_offset: int = 0, - with_prior_preservation: bool = False, + guidance_scale: float = 0.0, prior_loss_weight: float = 1.0, - offset_noise_strength: float = 0.1, + offset_noise_strength: float = 0.15, **kwargs, ): text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( @@ -638,7 +644,7 @@ def train( noise_scheduler, unet, text_encoder, - with_prior_preservation, + guidance_scale, prior_loss_weight, seed, offset_noise_strength, -- cgit v1.2.3-70-g09d2