diff options
author | Volpeon <git@volpeon.ink> | 2023-01-13 07:25:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-13 07:25:24 +0100 |
commit | 89d471652644f449966a0cd944041c98dab7f66c (patch) | |
tree | 4cc797369a5c781b4978b89a61023c4de7fde606 /train_dreambooth.py | |
parent | Update (diff) | |
download | textual-inversion-diff-89d471652644f449966a0cd944041c98dab7f66c.tar.gz textual-inversion-diff-89d471652644f449966a0cd944041c98dab7f66c.tar.bz2 textual-inversion-diff-89d471652644f449966a0cd944041c98dab7f66c.zip |
Code deduplication
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 71 |
1 file changed, 15 insertions, 56 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index ebcf802..da3a075 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -14,7 +14,6 @@ from accelerate import Accelerator | |||
14 | from accelerate.logging import get_logger | 14 | from accelerate.logging import get_logger |
15 | from accelerate.utils import LoggerType, set_seed | 15 | from accelerate.utils import LoggerType, set_seed |
16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | ||
18 | import matplotlib.pyplot as plt | 17 | import matplotlib.pyplot as plt |
19 | from diffusers.training_utils import EMAModel | 18 | from diffusers.training_utils import EMAModel |
20 | from tqdm.auto import tqdm | 19 | from tqdm.auto import tqdm |
@@ -24,8 +23,7 @@ from slugify import slugify | |||
24 | from util import load_config, load_embeddings_from_dir | 23 | from util import load_config, load_embeddings_from_dir |
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
26 | from data.csv import VlpnDataModule, VlpnDataItem | 25 | from data.csv import VlpnDataModule, VlpnDataItem |
27 | from training.common import loss_step, generate_class_images | 26 | from training.common import loss_step, generate_class_images, get_scheduler |
28 | from training.optimization import get_one_cycle_schedule | ||
29 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
30 | from training.util import AverageMeter, CheckpointerBase, save_args | 28 | from training.util import AverageMeter, CheckpointerBase, save_args |
31 | from models.clip.embeddings import patch_managed_embeddings | 29 | from models.clip.embeddings import patch_managed_embeddings |
@@ -750,35 +748,6 @@ def main(): | |||
750 | ) | 748 | ) |
751 | return cond3 and cond4 | 749 | return cond3 and cond4 |
752 | 750 | ||
753 | def collate_fn(examples): | ||
754 | prompt_ids = [example["prompt_ids"] for example in examples] | ||
755 | nprompt_ids = [example["nprompt_ids"] for example in examples] | ||
756 | |||
757 | input_ids = [example["instance_prompt_ids"] for example in examples] | ||
758 | pixel_values = [example["instance_images"] for example in examples] | ||
759 | |||
760 | # concat class and instance examples for prior preservation | ||
761 | if args.num_class_images != 0 and "class_prompt_ids" in examples[0]: | ||
762 | input_ids += [example["class_prompt_ids"] for example in examples] | ||
763 | pixel_values += [example["class_images"] for example in examples] | ||
764 | |||
765 | pixel_values = torch.stack(pixel_values) | ||
766 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | ||
767 | |||
768 | prompts = prompt_processor.unify_input_ids(prompt_ids) | ||
769 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | ||
770 | inputs = prompt_processor.unify_input_ids(input_ids) | ||
771 | |||
772 | batch = { | ||
773 | "prompt_ids": prompts.input_ids, | ||
774 | "nprompt_ids": nprompts.input_ids, | ||
775 | "input_ids": inputs.input_ids, | ||
776 | "pixel_values": pixel_values, | ||
777 | "attention_mask": inputs.attention_mask, | ||
778 | } | ||
779 | |||
780 | return batch | ||
781 | |||
782 | datamodule = VlpnDataModule( | 751 | datamodule = VlpnDataModule( |
783 | data_file=args.train_data_file, | 752 | data_file=args.train_data_file, |
784 | batch_size=args.train_batch_size, | 753 | batch_size=args.train_batch_size, |
@@ -798,7 +767,7 @@ def main(): | |||
798 | num_workers=args.dataloader_num_workers, | 767 | num_workers=args.dataloader_num_workers, |
799 | seed=args.seed, | 768 | seed=args.seed, |
800 | filter=keyword_filter, | 769 | filter=keyword_filter, |
801 | collate_fn=collate_fn | 770 | dtype=weight_dtype |
802 | ) | 771 | ) |
803 | 772 | ||
804 | datamodule.prepare_data() | 773 | datamodule.prepare_data() |
@@ -829,33 +798,23 @@ def main(): | |||
829 | overrode_max_train_steps = True | 798 | overrode_max_train_steps = True |
830 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 799 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
831 | 800 | ||
832 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps | 801 | if args.find_lr: |
833 | 802 | lr_scheduler = None | |
834 | if args.lr_scheduler == "one_cycle": | ||
835 | lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate | ||
836 | lr_scheduler = get_one_cycle_schedule( | ||
837 | optimizer=optimizer, | ||
838 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
839 | warmup=args.lr_warmup_func, | ||
840 | annealing=args.lr_annealing_func, | ||
841 | warmup_exp=args.lr_warmup_exp, | ||
842 | annealing_exp=args.lr_annealing_exp, | ||
843 | min_lr=lr_min_lr, | ||
844 | ) | ||
845 | elif args.lr_scheduler == "cosine_with_restarts": | ||
846 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | ||
847 | optimizer=optimizer, | ||
848 | num_warmup_steps=warmup_steps, | ||
849 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
850 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( | ||
851 | ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), | ||
852 | ) | ||
853 | else: | 803 | else: |
854 | lr_scheduler = get_scheduler( | 804 | lr_scheduler = get_scheduler( |
855 | args.lr_scheduler, | 805 | args.lr_scheduler, |
856 | optimizer=optimizer, | 806 | optimizer=optimizer, |
857 | num_warmup_steps=warmup_steps, | 807 | min_lr=args.lr_min_lr, |
858 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 808 | lr=args.learning_rate, |
809 | warmup_func=args.lr_warmup_func, | ||
810 | annealing_func=args.lr_annealing_func, | ||
811 | warmup_exp=args.lr_warmup_exp, | ||
812 | annealing_exp=args.lr_annealing_exp, | ||
813 | cycles=args.lr_cycles, | ||
814 | warmup_epochs=args.lr_warmup_epochs, | ||
815 | max_train_steps=args.max_train_steps, | ||
816 | num_update_steps_per_epoch=num_update_steps_per_epoch, | ||
817 | gradient_accumulation_steps=args.gradient_accumulation_steps | ||
859 | ) | 818 | ) |
860 | 819 | ||
861 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 820 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |