author    Volpeon <git@volpeon.ink>  2023-01-05 10:19:38 +0100
committer Volpeon <git@volpeon.ink>  2023-01-05 10:19:38 +0100
commit    6c64f769043c8212b1a5778e857af691a828798d (patch)
tree      fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /training
parent    Update (diff)
Various cleanups
Diffstat (limited to 'training')
-rw-r--r--  training/common.py  |  75
1 file changed, 75 insertions(+), 0 deletions(-)
diff --git a/training/common.py b/training/common.py
new file mode 100644
index 0000000..99a6e67
--- /dev/null
+++ b/training/common.py
@@ -0,0 +1,75 @@
import torch
import torch.nn.functional as F

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel


def run_model(
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    prompt_processor,
    num_class_images: int,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch,
    eval: bool = False
):
    # Convert images to latent space
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * 0.18215
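    # 0.18215 is the Stable Diffusion VAE's latent scaling factor; multiplying
    # by it brings the latents to roughly unit variance before the UNet sees them.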

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
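    # In eval mode, derive the generator seed from (seed + step) so validation
    # draws the same timesteps on every run and losses stay comparable.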
    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=timesteps_gen,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
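    # "epsilon" means the UNet predicts the added noise itself; "v_prediction"
    # means it predicts v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents
    # (Salimans & Ho, 2022), which get_velocity derives from the scheduler's alphas.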
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # Chunk the model prediction and target into two parts and compute the loss on each part separately.
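        # Prior preservation (DreamBooth-style): the batch is expected to hold
        # instance examples followed by class examples along dim 0.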
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
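    # Note: the "accuracy" below is exact element-wise float equality between
    # prediction and target, so it is a very strict diagnostic signal.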

    acc = (model_pred == target).float().mean()

    return loss, acc, bsz
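
For context, a minimal sketch of how run_model might be driven from a training loop. The surrounding pieces (train_dataloader, optimizer, and the already-loaded vae, unet, noise_scheduler, prompt_processor) are assumptions for illustration, not part of this commit:

    # Hypothetical training loop; every object referenced here is assumed to be
    # constructed elsewhere (models loaded, dataloader and optimizer configured).
    for step, batch in enumerate(train_dataloader):
        loss, acc, bsz = run_model(
            vae=vae,
            noise_scheduler=noise_scheduler,
            unet=unet,
            prompt_processor=prompt_processor,
            num_class_images=num_class_images,
            prior_loss_weight=1.0,  # example weight; tune for your data
            seed=seed,
            step=step,
            batch=batch,
        )
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()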