From 8d2aa65402c829583e26cdf2c336b8d3057657d6 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 5 May 2023 10:51:14 +0200 Subject: Update --- training/functional.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 38dd59f..e7e1eb3 100644 --- a/training/functional.py +++ b/training/functional.py @@ -324,6 +324,7 @@ def loss_step( prior_loss_weight: float, seed: int, offset_noise_strength: float, + input_pertubation: float, disc: Optional[ConvNeXtDiscriminator], min_snr_gamma: int, step: int, @@ -337,7 +338,7 @@ def loss_step( # Convert images to latent space latents = vae.encode(images).latent_dist.sample(generator=generator) - latents *= vae.config.scaling_factor + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn( @@ -355,7 +356,10 @@ def loss_step( device=latents.device, generator=generator ).expand(noise.shape) - noise += offset_noise_strength * offset_noise + noise = noise + offset_noise_strength * offset_noise + + if input_pertubation != 0: + new_noise = noise + input_pertubation * torch.randn_like(noise) # Sample a random timestep for each image timesteps = torch.randint( @@ -369,7 +373,10 @@ def loss_step( # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + if input_pertubation != 0: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noisy_latents.to(dtype=unet.dtype) # Get the text embedding for conditioning @@ -381,7 +388,7 @@ def loss_step( encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] if guidance_scale != 0: uncond_encoder_hidden_states = get_extended_embeddings( @@ -391,7 +398,7 @@ def loss_step( ) uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) - model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states).sample + model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) # Get the target for loss depending on the prediction type @@ -424,9 +431,9 @@ def loss_step( if disc is not None: rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps) - rec_latent /= vae.config.scaling_factor + rec_latent = rec_latent / vae.config.scaling_factor rec_latent = rec_latent.to(dtype=vae.dtype) - rec = vae.decode(rec_latent).sample + rec = vae.decode(rec_latent, return_dict=False)[0] loss = 1 - disc.get_score(rec) if min_snr_gamma != 0: @@ -434,7 +441,7 @@ def loss_step( mse_loss_weights = ( torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) - loss *= mse_loss_weights + loss = loss * mse_loss_weights loss = loss.mean() @@ -539,7 +546,7 @@ def train_loop( with on_train(cycle): for step, batch in enumerate(train_dataloader): loss, acc, bsz = loss_step(step, batch, cache) - loss /= gradient_accumulation_steps + loss = loss / gradient_accumulation_steps accelerator.backward(loss) @@ -598,7 +605,7 @@ def train_loop( with torch.inference_mode(), on_eval(): for step, batch in enumerate(val_dataloader): loss, acc, bsz = loss_step(step, batch, cache, True) - loss /= gradient_accumulation_steps + loss = loss / gradient_accumulation_steps cur_loss_val.update(loss.item(), bsz) cur_acc_val.update(acc.item(), bsz) @@ -684,7 +691,8 @@ def train( global_step_offset: int = 0, guidance_scale: float = 0.0, prior_loss_weight: float = 1.0, - offset_noise_strength: float = 0.15, + offset_noise_strength: float = 0.01, + input_pertubation: float = 0.1, disc: Optional[ConvNeXtDiscriminator] = None, min_snr_gamma: int = 5, avg_loss: AverageMeter = AverageMeter(), @@ -704,7 +712,7 @@ def train( if compile_unet: unet = torch.compile(unet, backend='hidet') - # unet = torch.compile(unet) + # unet = torch.compile(unet, mode="reduce-overhead") callbacks = strategy.callbacks( accelerator=accelerator, @@ -727,6 +735,7 @@ def train( prior_loss_weight, seed, offset_noise_strength, + input_pertubation, disc, min_snr_gamma, ) -- cgit v1.2.3-70-g09d2