From 6c64f769043c8212b1a5778e857af691a828798d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 5 Jan 2023 10:19:38 +0100 Subject: Various cleanups --- training/common.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 training/common.py (limited to 'training') diff --git a/training/common.py b/training/common.py new file mode 100644 index 0000000..99a6e67 --- /dev/null +++ b/training/common.py @@ -0,0 +1,75 @@ +import torch +import torch.nn.functional as F + +from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + + +def run_model( + vae: AutoencoderKL, + noise_scheduler: DDPMScheduler, + unet: UNet2DConditionModel, + prompt_processor, + num_class_images: int, + prior_loss_weight: float, + seed: int, + step: int, + batch, + eval: bool = False +): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + generator=timesteps_gen, + device=latents.device, + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noisy_latents.to(dtype=unet.dtype) + + # Get the text embedding for conditioning + encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) + encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if num_class_images != 0: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + acc = (model_pred == target).float().mean() + + return loss, acc, bsz -- cgit v1.2.3-54-g00ecf