diff options
author | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
commit | 6c64f769043c8212b1a5778e857af691a828798d (patch) | |
tree | fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /train_ti.py | |
parent | Update (diff) | |
download | textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2 textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip |
Various cleanups
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 86 |
1 file changed, 21 insertions, 65 deletions
diff --git a/train_ti.py b/train_ti.py index 8ada98c..5df6850 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -3,9 +3,9 @@ import math | |||
3 | import datetime | 3 | import datetime |
4 | import logging | 4 | import logging |
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from functools import partial | ||
6 | 7 | ||
7 | import torch | 8 | import torch |
8 | import torch.nn.functional as F | ||
9 | import torch.utils.checkpoint | 9 | import torch.utils.checkpoint |
10 | 10 | ||
11 | from accelerate import Accelerator | 11 | from accelerate import Accelerator |
@@ -18,9 +18,10 @@ from tqdm.auto import tqdm | |||
18 | from transformers import CLIPTextModel | 18 | from transformers import CLIPTextModel |
19 | from slugify import slugify | 19 | from slugify import slugify |
20 | 20 | ||
21 | from common import load_config, load_embeddings_from_dir | 21 | from util import load_config, load_embeddings_from_dir |
22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
23 | from data.csv import CSVDataModule, CSVDataItem | 23 | from data.csv import CSVDataModule, CSVDataItem |
24 | from training.common import run_model | ||
24 | from training.optimization import get_one_cycle_schedule | 25 | from training.optimization import get_one_cycle_schedule |
25 | from training.lr import LRFinder | 26 | from training.lr import LRFinder |
26 | from training.util import AverageMeter, CheckpointerBase, save_args | 27 | from training.util import AverageMeter, CheckpointerBase, save_args |
@@ -570,8 +571,8 @@ def main(): | |||
570 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 571 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
571 | raise ValueError("--embeddings_dir must point to an existing directory") | 572 | raise ValueError("--embeddings_dir must point to an existing directory") |
572 | 573 | ||
573 | added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 574 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
574 | print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") | 575 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") |
575 | 576 | ||
576 | # Convert the initializer_token, placeholder_token to ids | 577 | # Convert the initializer_token, placeholder_token to ids |
577 | initializer_token_ids = [ | 578 | initializer_token_ids = [ |
@@ -579,13 +580,15 @@ def main(): | |||
579 | for token in args.initializer_token | 580 | for token in args.initializer_token |
580 | ] | 581 | ] |
581 | 582 | ||
582 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 583 | new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
583 | embeddings.resize(len(tokenizer)) | 584 | embeddings.resize(len(tokenizer)) |
584 | 585 | ||
585 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): | 586 | init_ratios = [ |
586 | embeddings.add_embed(new_token.ids, init_ids) | 587 | embeddings.add_embed(new_id, init_ids) |
588 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) | ||
589 | ] | ||
587 | 590 | ||
588 | print(f"Added {len(new_tokens)} new tokens.") | 591 | print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}") |
589 | 592 | ||
590 | vae.requires_grad_(False) | 593 | vae.requires_grad_(False) |
591 | unet.requires_grad_(False) | 594 | unet.requires_grad_(False) |
@@ -807,63 +810,16 @@ def main(): | |||
807 | def on_eval(): | 810 | def on_eval(): |
808 | tokenizer.eval() | 811 | tokenizer.eval() |
809 | 812 | ||
810 | def loop(step: int, batch, eval: bool = False): | 813 | loop = partial( |
811 | # Convert images to latent space | 814 | run_model, |
812 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 815 | vae=vae, |
813 | latents = latents * 0.18215 | 816 | noise_scheduler=noise_scheduler, |
814 | 817 | unet=unet, | |
815 | # Sample noise that we'll add to the latents | 818 | prompt_processor=prompt_processor, |
816 | noise = torch.randn_like(latents) | 819 | num_class_images=args.num_class_images, |
817 | bsz = latents.shape[0] | 820 | prior_loss_weight=args.prior_loss_weight, |
818 | # Sample a random timestep for each image | 821 | seed=args.seed, |
819 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None | 822 | ) |
820 | timesteps = torch.randint( | ||
821 | 0, | ||
822 | noise_scheduler.config.num_train_timesteps, | ||
823 | (bsz,), | ||
824 | generator=timesteps_gen, | ||
825 | device=latents.device, | ||
826 | ) | ||
827 | timesteps = timesteps.long() | ||
828 | |||
829 | # Add noise to the latents according to the noise magnitude at each timestep | ||
830 | # (this is the forward diffusion process) | ||
831 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
832 | |||
833 | # Get the text embedding for conditioning | ||
834 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
835 | encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) | ||
836 | |||
837 | # Predict the noise residual | ||
838 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
839 | |||
840 | # Get the target for loss depending on the prediction type | ||
841 | if noise_scheduler.config.prediction_type == "epsilon": | ||
842 | target = noise | ||
843 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
844 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
845 | else: | ||
846 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
847 | |||
848 | if args.num_class_images != 0: | ||
849 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
850 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
851 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
852 | |||
853 | # Compute instance loss | ||
854 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
855 | |||
856 | # Compute prior loss | ||
857 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
858 | |||
859 | # Add the prior loss to the instance loss. | ||
860 | loss = loss + args.prior_loss_weight * prior_loss | ||
861 | else: | ||
862 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
863 | |||
864 | acc = (model_pred == target).float().mean() | ||
865 | |||
866 | return loss, acc, bsz | ||
867 | 823 | ||
868 | # We need to initialize the trackers we use, and also store our configuration. | 824 | # We need to initialize the trackers we use, and also store our configuration. |
869 | # The trackers initializes automatically on the main process. | 825 | # The trackers initializes automatically on the main process. |