summaryrefslogtreecommitdiffstats
path: root/training/common.py
blob: ab2741a70271f950383ab6ef94ec41431b48bfbe (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn.functional as F

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel


def run_model(
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    prompt_processor,
    num_class_images: int,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch,
    eval: bool = False
):
    """Run one noise-prediction pass over ``batch`` and compute its loss.

    The images in ``batch["pixel_values"]`` are encoded to latents with
    ``vae``, noised at randomly drawn timesteps via ``noise_scheduler``, and
    passed through ``unet`` together with the text embeddings produced by
    ``prompt_processor``. The prediction is compared against the target that
    matches the scheduler's ``prediction_type`` using mean squared error.

    Args:
        vae: Autoencoder used to encode ``batch["pixel_values"]``.
        noise_scheduler: Scheduler providing ``add_noise``, ``get_velocity``
            and the ``num_train_timesteps`` / ``prediction_type`` config.
        unet: Model called as ``unet(noisy_latents, timesteps, embeddings)``;
            its ``dtype`` is used to cast the model inputs.
        prompt_processor: Object whose ``get_embeddings(input_ids,
            attention_mask)`` returns the conditioning embeddings.
        num_class_images: When non-zero, the batch is split in two equal
            chunks along dim 0; the first chunk yields the instance loss and
            the second the prior loss.
        prior_loss_weight: Factor applied to the prior loss before it is
            added to the instance loss.
        seed: Base seed for the evaluation random generator.
        step: Added to ``seed`` so every evaluation step draws different,
            but reproducible, noise and timesteps.
        batch: Mapping with the keys ``"pixel_values"``, ``"input_ids"`` and
            ``"attention_mask"``.
        eval: When true, noise and timesteps are drawn from a generator
            seeded with ``seed + step`` instead of the global RNG. The
            latent sample drawn from the VAE posterior is not seeded here.

    Returns:
        A tuple ``(loss, acc, bsz)``: the scalar loss tensor, the fraction of
        prediction elements exactly equal to the target, and the size of the
        full batch along dim 0.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is neither
            ``"epsilon"`` nor ``"v_prediction"``.
    """
    # Convert images to latent space
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    # NOTE(review): 0.18215 is presumably the Stable Diffusion VAE latent
    # scaling factor -- confirm against the VAE config in use.
    latents = latents * 0.18215

    # One generator drives both the noise and the timesteps so that an
    # evaluation step is reproducible for a given (seed, step). During
    # training it is None and the global RNG is used.
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    # Sample noise that we'll add to the latents
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = prompt_processor.get_embeddings(
        batch["input_ids"],
        batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # NOTE(review): exact element-wise equality between floating-point tensors;
    # when num_class_images != 0 this only covers the instance chunk because
    # model_pred/target were rebound above. Kept as-is so the returned tuple
    # keeps its meaning for callers -- confirm this metric is still wanted.
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz