summaryrefslogtreecommitdiffstats
path: root/training/common.py
blob: 0b2ae44a389de1a1821b2be21792af2ed66698e4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
import torch.nn.functional as F

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion


def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    """Render and save the class image for every item whose file is missing.

    Items in ``data_train`` whose ``class_image_path`` already exists are
    skipped. The remaining items are processed in chunks of
    ``sample_batch_size`` through a ``VlpnStableDiffusion`` pipeline built from
    the given components and moved to ``accelerator.device``. Each item's
    ``cprompt`` is the prompt and its ``nprompt`` the negative prompt. Every
    generated image is saved to the matching item's ``class_image_path``.

    Args:
        accelerator: Object providing ``.device`` for the pipeline.
        text_encoder, vae, unet, tokenizer, scheduler: Pipeline components.
        data_train: Iterable of items exposing ``class_image_path``,
            ``cprompt`` and ``nprompt``.
        sample_batch_size: Number of prompts rendered per pipeline call.
        sample_image_size: Height and width of the generated images.
        sample_steps: Number of inference steps per image.

    Returns:
        None. Images are written to disk as a side effect.
    """
    pending = [entry for entry in data_train if not entry.class_image_path.exists()]
    if not pending:
        # Nothing to render: skip building the pipeline entirely.
        return

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for start in range(0, len(pending), sample_batch_size):
            chunk = pending[start:start + sample_batch_size]
            output_paths = [entry.class_image_path for entry in chunk]

            result = pipeline(
                prompt=[entry.cprompt for entry in chunk],
                negative_prompt=[entry.nprompt for entry in chunk],
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            )

            for idx, picture in enumerate(result.images):
                picture.save(output_paths[idx])

    # Drop the pipeline reference and release cached GPU memory before training resumes.
    del pipeline

    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def loss_step(
    vae: "AutoencoderKL",
    noise_scheduler: "DDPMScheduler",
    unet: "UNet2DConditionModel",
    prompt_processor,
    num_class_images: int,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch,
    eval: bool = False
):
    """Compute the noise-prediction loss for one batch.

    The batch's ``pixel_values`` are encoded with ``vae`` and scaled. Noise is
    added at a random timestep per sample via ``noise_scheduler.add_noise``.
    ``unet`` then predicts the target selected by the scheduler's
    ``prediction_type``, conditioned on ``prompt_processor.get_embeddings``.

    In ``eval`` mode, both the noise and the timesteps are drawn from a
    generator seeded with ``seed + step``, so repeated evaluations of the same
    step see identical noise and timesteps. In training mode both come from
    the global RNG.

    Args:
        vae: Provides ``encode(...).latent_dist.sample()``.
        noise_scheduler: Provides ``config.num_train_timesteps``,
            ``config.prediction_type``, ``add_noise`` and ``get_velocity``.
        unet: Called as ``unet(noisy_latents, timesteps, hidden_states)``;
            its ``.dtype`` is used for the inputs.
        prompt_processor: Provides ``get_embeddings(input_ids, attention_mask)``.
        num_class_images: When non-zero, the batch is split in half along
            dim 0. The first half gets the instance loss and the second half
            the prior loss, weighted by ``prior_loss_weight``.
        prior_loss_weight: Weight of the prior-loss term.
        seed: Base seed for eval-mode sampling.
        step: Added to ``seed`` in eval mode.
        batch: Mapping with ``"pixel_values"``, ``"input_ids"`` and
            ``"attention_mask"``.
        eval: Whether to use seeded (reproducible) sampling.

    Returns:
        Tuple ``(loss, acc, bsz)``:
        - ``loss`` is a scalar tensor.
        - ``acc`` is the fraction of elements where the (instance-half)
          prediction exactly equals the target.
        - ``bsz`` is the batch size.

    Raises:
        ValueError: If ``prediction_type`` is neither ``"epsilon"`` nor
            ``"v_prediction"``.
    """
    # Convert images to latent space
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    # NOTE(review): presumably the Stable Diffusion VAE latent scaling factor;
    # confirm it matches the VAE in use.
    latents = latents * 0.18215

    # One generator drives all diffusion randomness in eval mode, so the
    # validation loss is reproducible; None falls back to the global RNG.
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    # Sample noise that we'll add to the latents
    noise = torch.randn(
        latents.shape,
        generator=generator,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
    )
    bsz = latents.shape[0]

    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = prompt_processor.get_embeddings(
        batch["input_ids"],
        batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # NOTE(review): exact float equality between prediction and target; for
    # continuous outputs this is almost always ~0. It is kept unchanged because
    # callers may log it.
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz