| author | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-05 10:19:38 +0100 |
| commit | 6c64f769043c8212b1a5778e857af691a828798d (patch) | |
| tree | fe4cdf2a4e28e86e31bb7ccd8885c0a42c8632dc /train_ti.py | |
| parent | Update (diff) | |
| download | textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.gz textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.tar.bz2 textual-inversion-diff-6c64f769043c8212b1a5778e857af691a828798d.zip | |
Various cleanups
Diffstat (limited to 'train_ti.py')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | train_ti.py | 86 |

1 file changed, 21 insertions, 65 deletions
```diff
diff --git a/train_ti.py b/train_ti.py
index 8ada98c..5df6850 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,9 +3,9 @@ import math
 import datetime
 import logging
 from pathlib import Path
+from functools import partial
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from accelerate import Accelerator
```
```diff
@@ -18,9 +18,10 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify
 
-from common import load_config, load_embeddings_from_dir
+from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
+from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
```
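The two import changes work together: `torch.nn.functional` moves out with the inline training loop, while `functools.partial` and `training.common.run_model` come in to replace it (see the last hunk below). A minimal, runnable sketch of the pattern with hypothetical names:

```python
from functools import partial

# Hypothetical stand-in for training.common.run_model: fixed dependencies
# are keyword-only, per-step arguments stay positional so the partial
# can still be called as loop(step, batch).
def run_model(step, batch, eval=False, *, scale, seed):
    return scale * (step + len(batch)) + seed

# Bind the fixed arguments once; the loop only supplies per-step values.
loop = partial(run_model, scale=0.5, seed=42)
print(loop(3, [1, 2]))  # equivalent to run_model(3, [1, 2], scale=0.5, seed=42)
```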
```diff
@@ -570,8 +571,8 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
-        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")
 
     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = [
```
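One caveat with the new print (and the similar one in the next hunk): a `zip` object inside an f-string renders as its repr, e.g. `<zip object at 0x...>`, not as the pairs. A sketch of the presumably intended output, assuming `load_embeddings_from_dir` now returns parallel lists of token strings and ids (the values below are made up):

```python
# Assumed return shape: parallel lists of token strings and token ids.
added_tokens, added_ids = ["<sun>", "<moon>"], [49408, 49409]

# list() materializes the zip so the f-string shows the actual pairs.
pairs = list(zip(added_tokens, added_ids))
print(f"Added {len(added_tokens)} tokens from embeddings dir: {pairs}")
# Added 2 tokens from embeddings dir: [('<sun>', 49408), ('<moon>', 49409)]
```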
```diff
@@ -579,13 +580,15 @@ def main():
         for token in args.initializer_token
     ]
 
-    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+    new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
     embeddings.resize(len(tokenizer))
 
-    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-        embeddings.add_embed(new_token.ids, init_ids)
+    init_ratios = [
+        embeddings.add_embed(new_id, init_ids)
+        for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
+    ]
 
-    print(f"Added {len(new_tokens)} new tokens.")
+    print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
 
     vae.requires_grad_(False)
     unet.requires_grad_(False)
```
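This hunk also starts keeping `embeddings.add_embed`'s return value, which implies `add_embed` now reports some initialization ratio per placeholder token; the new print has the same zip caveat as above. A toy illustration, where the ratio format and all values are invented for the example:

```python
# Invented stand-in: assume add_embed reports how many of a token's vectors
# could be initialized from the given initializer ids.
def add_embed(new_id, init_ids):
    return f"{min(len(new_id), len(init_ids))}/{len(new_id)}"

placeholder_token = ["<sun>", "<moon>"]
new_ids = [[49408, 49409], [49410]]
initializer_token_ids = [[320], [699]]

init_ratios = [add_embed(n, i) for n, i in zip(new_ids, initializer_token_ids)]
print(f"Added {len(new_ids)} new tokens: {list(zip(placeholder_token, new_ids, init_ratios))}")
# Added 2 new tokens: [('<sun>', [49408, 49409], '1/2'), ('<moon>', [49410], '1/1')]
```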
```diff
@@ -807,63 +810,16 @@ def main():
     def on_eval():
         tokenizer.eval()
 
-    def loop(step: int, batch, eval: bool = False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
-        timesteps = torch.randint(
-            0,
-            noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            generator=timesteps_gen,
-            device=latents.device,
-        )
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-        encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
+    loop = partial(
+        run_model,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        unet=unet,
+        prompt_processor=prompt_processor,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        seed=args.seed,
+    )
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
```
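For reference, the deleted closure is presumably what `training.common.run_model` now contains. Below is a sketch reconstructed from the removed lines; the actual signature in training/common.py is an assumption (per-step arguments first, fixed dependencies keyword-only, so it composes with the `partial` call above), and the original's cast of the hidden states to `weight_dtype`, a closure variable the partial does not bind, is noted but left out.

```python
import torch
import torch.nn.functional as F


# Hypothetical reconstruction of training.common.run_model from the deleted
# loop; the real signature may differ. `eval` mirrors the old closure's
# keyword and intentionally shadows the builtin, as the original did.
def run_model(step, batch, eval=False, *, vae, noise_scheduler, unet,
              prompt_processor, num_class_images, prior_loss_weight, seed):
    # Convert images to latent space and apply the SD latent scaling factor
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]

    # During evaluation, seed the timestep sampler per step for reproducibility
    generator = (
        torch.Generator(device=latents.device).manual_seed(seed + step)
        if eval
        else None
    )
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    ).long()

    # Forward diffusion: noise the latents according to each timestep
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Text embeddings for conditioning (the original also cast these to
    # weight_dtype, which run_model would have to receive or infer)
    encoder_hidden_states = prompt_processor.get_embeddings(
        batch["input_ids"], batch["attention_mask"]
    )

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Pick the loss target based on the scheduler's prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if num_class_images != 0:
        # Prior preservation: the batch holds instance and class halves
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Exact-match "accuracy" between prediction and target, as in the original
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz
```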
