| author | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
| commit | 6c64f769043c8212b1a5778e857af691a828798d (patch) | |
| tree | fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /training | |
| parent | Update (diff) | |
| download | textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz, textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2, textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip | |
Various cleanups
Diffstat (limited to 'training')
| -rw-r--r-- | training/common.py | 75 |
1 file changed, 75 insertions, 0 deletions
diff --git a/training/common.py b/training/common.py
new file mode 100644
index 0000000..99a6e67
--- /dev/null
+++ b/training/common.py
@@ -0,0 +1,75 @@

```python
import torch
import torch.nn.functional as F

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel


def run_model(
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    prompt_processor,
    num_class_images: int,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch,
    eval: bool = False
):
    # Convert images to latent space
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * 0.18215
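    # 0.18215 is the latent scaling factor from the original Stable Diffusion
    # VAE config (exposed as vae.config.scaling_factor in newer diffusers releases)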

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
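    # Seeding with seed + step makes eval timesteps reproducible across runs
    # while still varying from step to step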
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=timesteps_gen,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
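    # add_noise computes sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise
    # for the cumulative alpha product alpha_bar_t at each sampled timestep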
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
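        # Velocity target from Progressive Distillation (Salimans & Ho, 2022):
        # v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents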
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # Chunk the model prediction and target into two parts and compute the loss on each part separately.
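        # Assumes the collator stacks instance examples in the first half of the
        # batch and class (prior-preservation) examples in the second half.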
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

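    # Fraction of elements where the prediction exactly equals the target
    # (a coarse metric, since exact float equality is rare in practice)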
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz
```
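
For context, a minimal sketch of how `run_model` might be driven. This is illustrative, not part of the commit: the checkpoint ID, the `PromptProcessor` stand-in (only the `get_embeddings` interface the function calls is assumed), and the dummy batch are all hypothetical.

```python
import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from training.common import run_model

model_id = "runwayml/stable-diffusion-v1-5"  # illustrative checkpoint
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")


class PromptProcessor:
    # Hypothetical stand-in for the repository's prompt processor; run_model
    # only calls get_embeddings(input_ids, attention_mask) on it.
    def __init__(self, text_encoder):
        self.text_encoder = text_encoder

    def get_embeddings(self, input_ids, attention_mask=None):
        return self.text_encoder(input_ids, attention_mask=attention_mask)[0]


prompt_processor = PromptProcessor(text_encoder)

# A dummy batch with the keys run_model reads; a real loader would supply
# normalized image tensors and tokenized prompts.
enc = tokenizer(
    ["a photo of a cat", "a photo of a dog"],
    padding="max_length",
    truncation=True,
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
)
batch = {
    "pixel_values": torch.randn(2, 3, 512, 512),
    "input_ids": enc.input_ids,
    "attention_mask": enc.attention_mask,
}

loss, acc, bsz = run_model(
    vae=vae,
    noise_scheduler=noise_scheduler,
    unet=unet,
    prompt_processor=prompt_processor,
    num_class_images=0,      # no prior-preservation images in this sketch
    prior_loss_weight=1.0,
    seed=42,
    step=0,
    batch=batch,
)
loss.backward()
```

With `num_class_images != 0` the function instead expects each batch to be the concatenation of instance and class examples, which is what the `torch.chunk` split relies on.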