summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-05 10:19:38 +0100
committerVolpeon <git@volpeon.ink>2023-01-05 10:19:38 +0100
commit6c64f769043c8212b1a5778e857af691a828798d (patch)
treefe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /train_ti.py
parentUpdate (diff)
downloadtextual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz
textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2
textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip
Various cleanups
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py86
1 file changed, 21 insertions, 65 deletions
diff --git a/train_ti.py b/train_ti.py
index 8ada98c..5df6850 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,9 +3,9 @@ import math
3import datetime 3import datetime
4import logging 4import logging
5from pathlib import Path 5from pathlib import Path
6from functools import partial
6 7
7import torch 8import torch
8import torch.nn.functional as F
9import torch.utils.checkpoint 9import torch.utils.checkpoint
10 10
11from accelerate import Accelerator 11from accelerate import Accelerator
@@ -18,9 +18,10 @@ from tqdm.auto import tqdm
18from transformers import CLIPTextModel 18from transformers import CLIPTextModel
19from slugify import slugify 19from slugify import slugify
20 20
21from common import load_config, load_embeddings_from_dir 21from util import load_config, load_embeddings_from_dir
22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23from data.csv import CSVDataModule, CSVDataItem 23from data.csv import CSVDataModule, CSVDataItem
24from training.common import run_model
24from training.optimization import get_one_cycle_schedule 25from training.optimization import get_one_cycle_schedule
25from training.lr import LRFinder 26from training.lr import LRFinder
26from training.util import AverageMeter, CheckpointerBase, save_args 27from training.util import AverageMeter, CheckpointerBase, save_args
@@ -570,8 +571,8 @@ def main():
570 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 571 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
571 raise ValueError("--embeddings_dir must point to an existing directory") 572 raise ValueError("--embeddings_dir must point to an existing directory")
572 573
573 added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 574 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
574 print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") 575 print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
575 576
576 # Convert the initializer_token, placeholder_token to ids 577 # Convert the initializer_token, placeholder_token to ids
577 initializer_token_ids = [ 578 initializer_token_ids = [
@@ -579,13 +580,15 @@ def main():
579 for token in args.initializer_token 580 for token in args.initializer_token
580 ] 581 ]
581 582
582 new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) 583 new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
583 embeddings.resize(len(tokenizer)) 584 embeddings.resize(len(tokenizer))
584 585
585 for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): 586 init_ratios = [
586 embeddings.add_embed(new_token.ids, init_ids) 587 embeddings.add_embed(new_id, init_ids)
588 for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
589 ]
587 590
588 print(f"Added {len(new_tokens)} new tokens.") 591 print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
589 592
590 vae.requires_grad_(False) 593 vae.requires_grad_(False)
591 unet.requires_grad_(False) 594 unet.requires_grad_(False)
@@ -807,63 +810,16 @@ def main():
807 def on_eval(): 810 def on_eval():
808 tokenizer.eval() 811 tokenizer.eval()
809 812
810 def loop(step: int, batch, eval: bool = False): 813 loop = partial(
811 # Convert images to latent space 814 run_model,
812 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() 815 vae=vae,
813 latents = latents * 0.18215 816 noise_scheduler=noise_scheduler,
814 817 unet=unet,
815 # Sample noise that we'll add to the latents 818 prompt_processor=prompt_processor,
816 noise = torch.randn_like(latents) 819 num_class_images=args.num_class_images,
817 bsz = latents.shape[0] 820 prior_loss_weight=args.prior_loss_weight,
818 # Sample a random timestep for each image 821 seed=args.seed,
819 timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None 822 )
820 timesteps = torch.randint(
821 0,
822 noise_scheduler.config.num_train_timesteps,
823 (bsz,),
824 generator=timesteps_gen,
825 device=latents.device,
826 )
827 timesteps = timesteps.long()
828
829 # Add noise to the latents according to the noise magnitude at each timestep
830 # (this is the forward diffusion process)
831 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
832
833 # Get the text embedding for conditioning
834 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
835 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
836
837 # Predict the noise residual
838 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
839
840 # Get the target for loss depending on the prediction type
841 if noise_scheduler.config.prediction_type == "epsilon":
842 target = noise
843 elif noise_scheduler.config.prediction_type == "v_prediction":
844 target = noise_scheduler.get_velocity(latents, noise, timesteps)
845 else:
846 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
847
848 if args.num_class_images != 0:
849 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
850 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
851 target, target_prior = torch.chunk(target, 2, dim=0)
852
853 # Compute instance loss
854 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
855
856 # Compute prior loss
857 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
858
859 # Add the prior loss to the instance loss.
860 loss = loss + args.prior_loss_weight * prior_loss
861 else:
862 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
863
864 acc = (model_pred == target).float().mean()
865
866 return loss, acc, bsz
867 823
868 # We need to initialize the trackers we use, and also store our configuration. 824 # We need to initialize the trackers we use, and also store our configuration.
869 # The trackers initializes automatically on the main process. 825 # The trackers initializes automatically on the main process.