diff options
author | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
commit | 6c64f769043c8212b1a5778e857af691a828798d (patch) | |
tree | fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /train_dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2 textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip |
Various cleanups
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 86 |
1 file changed, 21 insertions, 65 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2e0696b..c658ad6 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -4,9 +4,9 @@ import math | |||
4 | import datetime | 4 | import datetime |
5 | import logging | 5 | import logging |
6 | from pathlib import Path | 6 | from pathlib import Path |
7 | from functools import partial | ||
7 | 8 | ||
8 | import torch | 9 | import torch |
9 | import torch.nn.functional as F | ||
10 | import torch.utils.checkpoint | 10 | import torch.utils.checkpoint |
11 | 11 | ||
12 | from accelerate import Accelerator | 12 | from accelerate import Accelerator |
@@ -20,9 +20,10 @@ from tqdm.auto import tqdm | |||
20 | from transformers import CLIPTextModel | 20 | from transformers import CLIPTextModel |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_config, load_embeddings_from_dir | 23 | from util import load_config, load_embeddings_from_dir |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
26 | from training.common import run_model | ||
26 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
27 | from training.lr import LRFinder | 28 | from training.lr import LRFinder |
28 | from training.util import AverageMeter, CheckpointerBase, save_args | 29 | from training.util import AverageMeter, CheckpointerBase, save_args |
@@ -610,8 +611,8 @@ def main(): | |||
610 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 611 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
611 | raise ValueError("--embeddings_dir must point to an existing directory") | 612 | raise ValueError("--embeddings_dir must point to an existing directory") |
612 | 613 | ||
613 | added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 614 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
614 | print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}") | 615 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}") |
615 | 616 | ||
616 | if len(args.placeholder_token) != 0: | 617 | if len(args.placeholder_token) != 0: |
617 | # Convert the initializer_token, placeholder_token to ids | 618 | # Convert the initializer_token, placeholder_token to ids |
@@ -620,13 +621,15 @@ def main(): | |||
620 | for token in args.initializer_token | 621 | for token in args.initializer_token |
621 | ] | 622 | ] |
622 | 623 | ||
623 | new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) | 624 | new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors) |
624 | embeddings.resize(len(tokenizer)) | 625 | embeddings.resize(len(tokenizer)) |
625 | 626 | ||
626 | for (new_token, init_ids) in zip(new_tokens, initializer_token_ids): | 627 | init_ratios = [ |
627 | embeddings.add_embed(new_token.ids, init_ids) | 628 | embeddings.add_embed(new_id, init_ids) |
629 | for (new_id, init_ids) in zip(new_ids, initializer_token_ids) | ||
630 | ] | ||
628 | 631 | ||
629 | print(f"Added {len(new_tokens)} new tokens.") | 632 | print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}") |
630 | else: | 633 | else: |
631 | placeholder_token_id = [] | 634 | placeholder_token_id = [] |
632 | 635 | ||
@@ -856,63 +859,16 @@ def main(): | |||
856 | def on_eval(): | 859 | def on_eval(): |
857 | tokenizer.eval() | 860 | tokenizer.eval() |
858 | 861 | ||
859 | def loop(step: int, batch, eval: bool = False): | 862 | loop = partial( |
860 | # Convert images to latent space | 863 | run_model, |
861 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 864 | vae=vae, |
862 | latents = latents * 0.18215 | 865 | noise_scheduler=noise_scheduler, |
863 | 866 | unet=unet, | |
864 | # Sample noise that we'll add to the latents | 867 | prompt_processor=prompt_processor, |
865 | noise = torch.randn_like(latents) | 868 | num_class_images=args.num_class_images, |
866 | bsz = latents.shape[0] | 869 | prior_loss_weight=args.prior_loss_weight, |
867 | # Sample a random timestep for each image | 870 | seed=args.seed, |
868 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None | 871 | ) |
869 | timesteps = torch.randint( | ||
870 | 0, | ||
871 | noise_scheduler.config.num_train_timesteps, | ||
872 | (bsz,), | ||
873 | generator=timesteps_gen, | ||
874 | device=latents.device, | ||
875 | ) | ||
876 | timesteps = timesteps.long() | ||
877 | |||
878 | # Add noise to the latents according to the noise magnitude at each timestep | ||
879 | # (this is the forward diffusion process) | ||
880 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | ||
881 | noisy_latents = noisy_latents.to(dtype=unet.dtype) | ||
882 | |||
883 | # Get the text embedding for conditioning | ||
884 | encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"]) | ||
885 | |||
886 | # Predict the noise residual | ||
887 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | ||
888 | |||
889 | # Get the target for loss depending on the prediction type | ||
890 | if noise_scheduler.config.prediction_type == "epsilon": | ||
891 | target = noise | ||
892 | elif noise_scheduler.config.prediction_type == "v_prediction": | ||
893 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | ||
894 | else: | ||
895 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | ||
896 | |||
897 | if args.num_class_images != 0: | ||
898 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | ||
899 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | ||
900 | target, target_prior = torch.chunk(target, 2, dim=0) | ||
901 | |||
902 | # Compute instance loss | ||
903 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
904 | |||
905 | # Compute prior loss | ||
906 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") | ||
907 | |||
908 | # Add the prior loss to the instance loss. | ||
909 | loss = loss + args.prior_loss_weight * prior_loss | ||
910 | else: | ||
911 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") | ||
912 | |||
913 | acc = (model_pred == target).float().mean() | ||
914 | |||
915 | return loss, acc, bsz | ||
916 | 872 | ||
917 | # We need to initialize the trackers we use, and also store our configuration. | 873 | # We need to initialize the trackers we use, and also store our configuration. |
918 | # The trackers initializes automatically on the main process. | 874 | # The trackers initializes automatically on the main process. |