| author    | Volpeon <git@volpeon.ink>                             | 2023-01-13 07:25:24 +0100 |
|-----------|-------------------------------------------------------|---------------------------|
| committer | Volpeon <git@volpeon.ink>                             | 2023-01-13 07:25:24 +0100 |
| commit    | 89d471652644f449966a0cd944041c98dab7f66c (patch)      |                           |
| tree      | 4cc797369a5c781b4978b89a61023c4de7fde606 /train_ti.py |                           |
| parent    | Update (diff)                                         |                           |
Code deduplication
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 86 |
1 file changed, 26 insertions(+), 60 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 9ec5cfb..3b7e3b1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel
@@ -22,8 +21,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -410,10 +408,16 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
-        "--max_grad_norm",
-        default=3.0,
+        "--decay_target",
+        default=0.4,
         type=float,
-        help="Max gradient norm."
+        help="Embedding decay target."
+    )
+    parser.add_argument(
+        "--decay_factor",
+        default=100,
+        type=float,
+        help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
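The two flags added here drive the embedding decay applied in the `on_after_optimize` hunk at the end of this diff. A minimal sketch of the arithmetic follows; the expression `min(1.0, decay_factor * lr)` is taken verbatim from that hunk, while the learning-rate value is made up for illustration:

```python
# Worked example of the per-step decay strength. The lr value is
# illustrative only, not a default from this repository.
decay_target = 0.4    # --decay_target: value the decay pulls the embeddings toward
decay_factor = 100.0  # --decay_factor: scales the current learning rate into a weight

lr = 5e-4
strength = min(1.0, decay_factor * lr)  # = 0.05 here, i.e. a 5% pull per step
print(f"each step moves embeddings {strength:.0%} of the way toward {decay_target}")
```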
@@ -709,35 +713,6 @@ def main():
         )
         return cond1 and cond3 and cond4
 
-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -757,7 +732,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )
     datamodule.setup()
 
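With `collate_fn` deleted from `train_ti.py`, batching presumably now lives inside `VlpnDataModule`, which would also explain the new `dtype=weight_dtype` argument. Below is a sketch of what the moved logic might look like, reconstructed from the deleted function; `data/csv.py` itself is not part of this diff, so the class skeleton and the parameter names around the collate logic are assumptions:

```python
# Hypothetical sketch: collate logic relocated into VlpnDataModule, with the
# cast dtype supplied via the module's new `dtype` parameter instead of a
# closure over weight_dtype. Mirrors the collate_fn removed above.
import torch

class VlpnDataModule:
    def __init__(self, dtype=torch.float32, num_class_images=0,
                 prompt_processor=None, **kwargs):
        self.dtype = dtype
        self.num_class_images = num_class_images
        self.prompt_processor = prompt_processor

    def collate_fn(self, examples):
        prompt_ids = [e["prompt_ids"] for e in examples]
        nprompt_ids = [e["nprompt_ids"] for e in examples]
        input_ids = [e["instance_prompt_ids"] for e in examples]
        pixel_values = [e["instance_images"] for e in examples]

        # concat class and instance examples for prior preservation
        if self.num_class_images != 0 and "class_prompt_ids" in examples[0]:
            input_ids += [e["class_prompt_ids"] for e in examples]
            pixel_values += [e["class_images"] for e in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(dtype=self.dtype, memory_format=torch.contiguous_format)

        prompts = self.prompt_processor.unify_input_ids(prompt_ids)
        nprompts = self.prompt_processor.unify_input_ids(nprompt_ids)
        inputs = self.prompt_processor.unify_input_ids(input_ids)

        return {
            "prompt_ids": prompts.input_ids,
            "nprompt_ids": nprompts.input_ids,
            "input_ids": inputs.input_ids,
            "pixel_values": pixel_values,
            "attention_mask": inputs.attention_mask,
        }
```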
@@ -786,35 +761,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
     if args.find_lr:
         lr_scheduler = None
-    elif args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
         )
 
     text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
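The duplicated scheduler setup collapses into a single call to a new `get_scheduler` helper in `training.common`. That module is not shown in this diff, so the following is only a plausible sketch, assembled from the deleted branches above and the keyword arguments visible at the new call site:

```python
# Hypothetical sketch of training.common.get_scheduler. The branch bodies
# reproduce the logic this commit deletes from train_ti.py; the wrapper's
# name and signature are inferred from the new call site.
import math
from diffusers.optimization import (
    get_scheduler as get_scheduler_,  # renamed to avoid clashing with the wrapper
    get_cosine_with_hard_restarts_schedule_with_warmup,
)
from training.optimization import get_one_cycle_schedule

def get_scheduler(name, optimizer, min_lr, lr, warmup_func, annealing_func,
                  warmup_exp, annealing_exp, cycles, warmup_epochs,
                  max_train_steps, num_update_steps_per_epoch,
                  gradient_accumulation_steps):
    # shared step bookkeeping, previously computed inline in train_ti.py
    num_training_steps = max_train_steps * gradient_accumulation_steps
    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps

    if name == "one_cycle":
        return get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=num_training_steps,
            warmup=warmup_func,
            annealing=annealing_func,
            warmup_exp=warmup_exp,
            annealing_exp=annealing_exp,
            min_lr=0.04 if min_lr is None else min_lr / lr,
        )
    if name == "cosine_with_restarts":
        return get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=cycles or math.ceil(math.sqrt(
                (max_train_steps - warmup_steps) / num_update_steps_per_epoch)),
        )
    # everything else falls through to the stock diffusers schedulers
    return get_scheduler_(
        name,
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )
```

Centralizing this also lets other training scripts in the repo share the same scheduler selection instead of each carrying its own copy, which matches the commit message.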
@@ -868,7 +831,10 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr))
+        text_encoder.text_model.embeddings.normalize(
+            args.decay_target,
+            min(1.0, args.decay_factor * lr)
+        )
 
     loop = partial(
         loss_step,
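Finally, `normalize` on the patched embeddings now takes the decay target plus a learning-rate-scaled weight instead of the hardcoded `100 * lr`. `models/clip/embeddings.py` is not in this diff, so the following is only a guess at the intent, assuming `--decay_target` is a target L2 norm:

```python
# Hypothetical sketch of what normalize(target, lambda_) might do: pull each
# trainable embedding's L2 norm a fraction lambda_ of the way toward target,
# keeping the vector's direction unchanged.
import torch

@torch.no_grad()
def normalize(embeddings: torch.Tensor, target: float = 0.4, lambda_: float = 1.0) -> torch.Tensor:
    norm = embeddings.norm(dim=-1, keepdim=True).clamp(min=1e-8)  # avoid divide-by-zero
    new_norm = norm + lambda_ * (target - norm)  # linear blend of current and target norm
    return embeddings * (new_norm / norm)        # rescale magnitudes, preserve directions
```

Under this reading, the `min(1.0, ...)` cap at the call site keeps the blend weight from overshooting past the target when the learning rate is large, while small learning rates give a proportionally gentler pull.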
