| author | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
| commit | 6c64f769043c8212b1a5778e857af691a828798d (patch) | |
| tree | fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /train_dreambooth.py | |
| parent | Update (diff) | |
Various cleanups
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 86 |
1 file changed, 21 insertions(+), 65 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2e0696b..c658ad6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,9 +4,9 @@ import math
 import datetime
 import logging
 from pathlib import Path
+from functools import partial
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from accelerate import Accelerator
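The `torch.nn.functional` import leaves together with the inline training-loop body (see the last hunk), while `functools.partial` comes in to support the refactor there. As a minimal reminder of the pattern, with a toy function that is not from the repo:

```python
from functools import partial

def scale(x: float, factor: float) -> float:
    return x * factor

# partial pre-binds keyword arguments; the result is called with the rest.
double = partial(scale, factor=2.0)
print(double(21.0))  # 42.0
```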
@@ -20,9 +20,10 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_config, load_embeddings_from_dir
+from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
+from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -610,8 +611,8 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
-        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
 
     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
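One caveat with the new print call: a zip object formats as its repr inside an f-string, so the pairs themselves are not shown. A quick illustration with made-up token names and ids:

```python
# Hypothetical values, for illustration only.
added_tokens = ["<token1>", "<token2>"]
added_ids = [49408, 49409]

print(f"{zip(added_tokens, added_ids)}")        # <zip object at 0x...>
print(f"{list(zip(added_tokens, added_ids))}")  # [('<token1>', 49408), ('<token2>', 49409)]
```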
@@ -620,13 +621,15 @@
             for token in args.initializer_token
         ]
 
-        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
         embeddings.resize(len(tokenizer))
 
-        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-            embeddings.add_embed(new_token.ids, init_ids)
+        init_ratios = [
+            embeddings.add_embed(new_id, init_ids)
+            for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
+        ]
 
-        print(f"Added {len(new_tokens)} new tokens.")
+        print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
     else:
         placeholder_token_id = []
 
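The same zip caveat applies to the print call above. Beyond that, this hunk changes the contract of the repo's custom tokenizer and embedding helpers: `add_multi_tokens` now returns plain id lists instead of objects with an `.ids` attribute, and `add_embed` now returns a value collected into `init_ratios`. A toy sketch of that contract; the class names, bodies, and the meaning of the return value are assumptions, not the repo's actual code:

```python
class ToyMultiTokenizer:
    """Stand-in for the repo's tokenizer wrapper (assumed behavior)."""
    def __init__(self, size: int):
        self.size = size

    def add_multi_tokens(self, placeholders, num_vectors):
        # Register each placeholder as num_vectors new tokens and
        # return the freshly allocated ids per placeholder.
        out = []
        for _token, n in zip(placeholders, num_vectors):
            out.append(list(range(self.size, self.size + n)))
            self.size += n
        return out

    def __len__(self):
        return self.size


class ToyEmbeddings:
    """Stand-in for the repo's embedding manager (assumed behavior)."""
    def resize(self, n):
        self.rows = n  # grow the embedding matrix to n rows

    def add_embed(self, new_ids, init_ids):
        # Initialize the new rows from the initializer tokens; assume the
        # returned "init ratio" relates new vectors to initializer vectors.
        return len(new_ids) / max(len(init_ids), 1)


tokenizer = ToyMultiTokenizer(size=49408)
embeddings = ToyEmbeddings()
new_ids = tokenizer.add_multi_tokens(["<vlpn>"], [4])
embeddings.resize(len(tokenizer))
init_ratios = [embeddings.add_embed(ids, [42]) for ids in new_ids]
print(new_ids, init_ratios)  # [[49408, 49409, 49410, 49411]] [4.0]
```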
@@ -856,63 +859,16 @@
     def on_eval():
         tokenizer.eval()
 
-    def loop(step: int, batch, eval: bool = False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
-        timesteps = torch.randint(
-            0,
-            noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            generator=timesteps_gen,
-            device=latents.device,
-        )
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
+    loop = partial(
+        run_model,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        unet=unet,
+        prompt_processor=prompt_processor,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        seed=args.seed,
+    )
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
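The inline loop body moves to `run_model` in `training/common.py`. That file is not part of this diff, but the removed code plus the partial-bound keywords suggest roughly the following shape; treat it as a plausible reconstruction, with `step`, `batch`, and `eval` assumed to stay positional:

```python
import torch
import torch.nn.functional as F


def run_model(step, batch, eval=False, *, vae, noise_scheduler, unet,
              prompt_processor, num_class_images, prior_loss_weight, seed):
    # Convert images to latent space and apply the SD scaling factor.
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
    latents = latents * 0.18215

    # Sample noise and one random timestep per image; seed deterministically
    # during evaluation so runs are comparable.
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    ).long()

    # Forward diffusion: add noise according to each timestep's magnitude.
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).to(dtype=unet.dtype)

    # Text conditioning, then predict the noise residual.
    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # The loss target depends on the scheduler's prediction type.
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # Instance and prior-preservation samples are concatenated in the
        # batch: split them and weight the prior loss separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    acc = (model_pred == target).float().mean()
    return loss, acc, bsz
```

With this shape, `loop(step, batch)` in the training code behaves exactly as the removed inline function did, since `partial` fills in the remaining keyword-only arguments.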
