import torch import torch.nn.functional as F from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel def run_model( vae: AutoencoderKL, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, prompt_processor, num_class_images: int, prior_loss_weight: float, seed: int, step: int, batch, eval: bool = False ): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (bsz,), generator=timesteps_gen, device=latents.device, ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noisy_latents.to(dtype=unet.dtype) # Get the text embedding for conditioning encoder_hidden_states = prompt_processor.get_embeddings( batch["input_ids"], batch["attention_mask"] ) encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) # Predict the noise residual model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if num_class_images != 0: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + prior_loss_weight * prior_loss else: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") acc = (model_pred == target).float().mean() return loss, acc, bsz