Diffstat (limited to 'training')
-rw-r--r--	training/common.py	75
1 file changed, 75 insertions, 0 deletions
diff --git a/training/common.py b/training/common.py
new file mode 100644
index 0000000..99a6e67
--- /dev/null
+++ b/training/common.py
@@ -0,0 +1,75 @@
import torch
import torch.nn.functional as F

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel


def run_model(
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    prompt_processor,
    num_class_images: int,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch,
    eval: bool = False
):
    # Convert images to latent space
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
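    # 0.18215 is the latent scaling factor from the Stable Diffusion VAE
    # config (exposed as vae.config.scaling_factor in newer diffusers).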
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
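    # In eval mode, a generator seeded with (seed + step) makes the sampled
    # timesteps reproducible across evaluation runs.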
    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=timesteps_gen,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
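    # For a DDPM scheduler this is the closed form
    #   noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise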
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
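        # get_velocity returns v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents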
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
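        # DreamBooth-style prior preservation: the batch is assumed to hold
        # instance examples in the first half and class (prior) examples in
        # the second half, concatenated along the batch dimension.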
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

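    # Note: exact elementwise equality between continuous tensors is almost
    # never true, so this "accuracy" will typically be 0; treat it as a
    # sanity check rather than a meaningful metric.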
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz
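
A minimal usage sketch, assuming standard Stable Diffusion checkpoints from
diffusers; the checkpoint id, the dataloader, and PromptProcessor (a stand-in
for whatever object the codebase uses to turn token ids into text-encoder
embeddings) are hypothetical, not part of this commit:

    from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
    from training.common import run_model

    model_id = "runwayml/stable-diffusion-v1-5"  # hypothetical example checkpoint
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    prompt_processor = PromptProcessor(...)  # hypothetical tokenizer/text-encoder wrapper

    # Each batch must provide "pixel_values", "input_ids" and "attention_mask".
    for step, batch in enumerate(dataloader):
        loss, acc, bsz = run_model(
            vae, noise_scheduler, unet, prompt_processor,
            num_class_images=0,       # > 0 enables the prior-preservation branch
            prior_loss_weight=1.0,
            seed=42,
            step=step,
            batch=batch,
        )
        loss.backward()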
