path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>   2023-01-05 10:19:38 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-05 10:19:38 +0100
commit     6c64f769043c8212b1a5778e857af691a828798d (patch)
tree       fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /train_dreambooth.py
parent     Update (diff)
Various cleanups
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  86
1 file changed, 21 insertions(+), 65 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2e0696b..c658ad6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,9 +4,9 @@ import math
 import datetime
 import logging
 from pathlib import Path
+from functools import partial
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from accelerate import Accelerator
@@ -20,9 +20,10 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_config, load_embeddings_from_dir
+from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
+from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
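
The newly imported functools.partial is what later replaces the inline loop closure: it binds the fixed keyword arguments once and returns a callable that only needs the per-step arguments. A minimal, self-contained sketch of the pattern (the names here are illustrative, not from the repo):

from functools import partial

def run(step, batch, eval=False, *, scale):
    # 'scale' is bound once via partial; 'step' and 'batch' vary per call
    return [x * scale for x in batch], step

loop = partial(run, scale=0.5)
print(loop(1, [2.0, 4.0]))  # ([1.0, 2.0], 1)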
@@ -610,8 +611,8 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
-        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
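
One thing worth noting about the new print call: interpolating a zip object into an f-string renders its repr, not the pairs. A quick illustration (the token string and id are invented for the example):

added_tokens, added_ids = ["<token1>"], [49408]
print(f"tokens: {zip(added_tokens, added_ids)}")        # tokens: <zip object at 0x...>
print(f"tokens: {list(zip(added_tokens, added_ids))}")  # tokens: [('<token1>', 49408)]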
@@ -620,13 +621,15 @@ def main():
             for token in args.initializer_token
         ]
 
-        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
         embeddings.resize(len(tokenizer))
 
-        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-            embeddings.add_embed(new_token.ids, init_ids)
+        init_ratios = [
+            embeddings.add_embed(new_id, init_ids)
+            for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
+        ]
 
-        print(f"Added {len(new_tokens)} new tokens.")
+        print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
     else:
         placeholder_token_id = []
 
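
The second change in this hunk swaps the imperative for-loop for a list comprehension so that add_embed's return values are collected as init_ratios. A schematic, runnable sketch with stand-in objects (FakeEmbeddings and the literal ids below are hypothetical; the real add_embed lives elsewhere in the repo):

class FakeEmbeddings:
    # Stand-in: report how many initializer ids seeded how many new vectors
    def add_embed(self, new_id, init_ids):
        return f"{len(init_ids)} -> {len(new_id)}"

embeddings = FakeEmbeddings()
new_ids = [[49408, 49409], [49410]]           # ids for two placeholder tokens
initializer_token_ids = [[320], [1125, 320]]  # their initializer token ids

init_ratios = [
    embeddings.add_embed(new_id, init_ids)
    for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
]
print(init_ratios)  # ['1 -> 2', '2 -> 1']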
@@ -856,63 +859,16 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(step: int, batch, eval: bool = False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
-        timesteps = torch.randint(
-            0,
-            noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            generator=timesteps_gen,
-            device=latents.device,
-        )
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
+    loop = partial(
+        run_model,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        unet=unet,
+        prompt_processor=prompt_processor,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        seed=args.seed,
+    )
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
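
For reference, the deleted loop body is what moved into run_model in training/common.py. The sketch below reconstructs that helper from the removed lines; the actual signature and parameter order in training/common.py may differ, and the keyword-only layout here is only an assumption chosen so that the partial(...) call above and the old loop(step, batch, eval) call sites both work.

import torch
import torch.nn.functional as F


def run_model(step, batch, eval=False, *, vae, noise_scheduler, unet,
              prompt_processor, num_class_images, prior_loss_weight, seed):
    # Convert images to latent space and apply the SD latent scaling factor
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    # Use a seeded generator during eval so timesteps are reproducible
    timesteps_gen = (
        torch.Generator(device=latents.device).manual_seed(seed + step)
        if eval else None
    )
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=timesteps_gen,
        device=latents.device,
    ).long()

    # Forward diffusion: noise the latents at the sampled timesteps
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning, then predict the noise residual
    encoder_hidden_states = prompt_processor.get_embeddings(
        batch["input_ids"], batch["attention_mask"])
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # The loss target depends on the scheduler's prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # With prior preservation, each batch holds [instance | class] halves;
        # chunk predictions and targets and compute the two losses separately
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    acc = (model_pred == target).float().mean()
    return loss, acc, bsz

With this shape, loop(step, batch) and loop(step, batch, True) behave like the removed closure, including the seeded timestep generator during evaluation.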