path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>  2023-01-13 07:25:24 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-13 07:25:24 +0100
commit     89d471652644f449966a0cd944041c98dab7f66c (patch)
tree       4cc797369a5c781b4978b89a61023c4de7fde606 /train_dreambooth.py
parent     Update (diff)
Code deduplication
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py | 71
1 file changed, 15 insertions(+), 56 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index ebcf802..da3a075 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -14,7 +14,6 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
-from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 import matplotlib.pyplot as plt
 from diffusers.training_utils import EMAModel
 from tqdm.auto import tqdm
@@ -24,8 +23,7 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, generate_class_images
-from training.optimization import get_one_cycle_schedule
+from training.common import loss_step, generate_class_images, get_scheduler
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
 from models.clip.embeddings import patch_managed_embeddings
@@ -750,35 +748,6 @@ def main():
         )
         return cond3 and cond4

-    def collate_fn(examples):
-        prompt_ids = [example["prompt_ids"] for example in examples]
-        nprompt_ids = [example["nprompt_ids"] for example in examples]
-
-        input_ids = [example["instance_prompt_ids"] for example in examples]
-        pixel_values = [example["instance_images"] for example in examples]
-
-        # concat class and instance examples for prior preservation
-        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
-
-        prompts = prompt_processor.unify_input_ids(prompt_ids)
-        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
-        inputs = prompt_processor.unify_input_ids(input_ids)
-
-        batch = {
-            "prompt_ids": prompts.input_ids,
-            "nprompt_ids": nprompts.input_ids,
-            "input_ids": inputs.input_ids,
-            "pixel_values": pixel_values,
-            "attention_mask": inputs.attention_mask,
-        }
-
-        return batch
-
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -798,7 +767,7 @@ def main():
         num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
-        collate_fn=collate_fn
+        dtype=weight_dtype
     )

     datamodule.prepare_data()
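With collate_fn deleted here and dtype=weight_dtype handed to VlpnDataModule instead, batching presumably moves into the data module itself. data/csv.py is not part of this diff, so the following is only a sketch of how a module-owned collate function could look, reusing the logic of the deleted closure; the constructor signature, the prompt_processor attribute, and the num_class_images attribute are assumptions, not code from this commit.

import torch

class VlpnDataModule:
    # Hypothetical excerpt: only the pieces needed to show module-owned collation.
    def __init__(self, prompt_processor, num_class_images=0, dtype=torch.float32, **kwargs):
        self.prompt_processor = prompt_processor      # assumed to be passed in elsewhere
        self.num_class_images = num_class_images
        self.dtype = dtype  # replaces the weight_dtype closure the old collate_fn captured

    def collate_fn(self, examples):
        prompt_ids = [example["prompt_ids"] for example in examples]
        nprompt_ids = [example["nprompt_ids"] for example in examples]

        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
        if self.num_class_images != 0 and "class_prompt_ids" in examples[0]:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(dtype=self.dtype, memory_format=torch.contiguous_format)

        prompts = self.prompt_processor.unify_input_ids(prompt_ids)
        nprompts = self.prompt_processor.unify_input_ids(nprompt_ids)
        inputs = self.prompt_processor.unify_input_ids(input_ids)

        return {
            "prompt_ids": prompts.input_ids,
            "nprompt_ids": nprompts.input_ids,
            "input_ids": inputs.input_ids,
            "pixel_values": pixel_values,
            "attention_mask": inputs.attention_mask,
        }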
@@ -829,33 +798,23 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

-    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
-
-    if args.lr_scheduler == "one_cycle":
-        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            warmup=args.lr_warmup_func,
-            annealing=args.lr_annealing_func,
-            warmup_exp=args.lr_warmup_exp,
-            annealing_exp=args.lr_annealing_exp,
-            min_lr=lr_min_lr,
-        )
-    elif args.lr_scheduler == "cosine_with_restarts":
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
-        )
+    if args.find_lr:
+        lr_scheduler = None
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=warmup_steps,
-            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+            min_lr=args.lr_min_lr,
+            lr=args.learning_rate,
+            warmup_func=args.lr_warmup_func,
+            annealing_func=args.lr_annealing_func,
+            warmup_exp=args.lr_warmup_exp,
+            annealing_exp=args.lr_annealing_exp,
+            cycles=args.lr_cycles,
+            warmup_epochs=args.lr_warmup_epochs,
+            max_train_steps=args.max_train_steps,
+            num_update_steps_per_epoch=num_update_steps_per_epoch,
+            gradient_accumulation_steps=args.gradient_accumulation_steps
         )

     unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
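The scheduler branching removed above now lives behind training.common.get_scheduler, whose implementation is not shown in this diff. Below is a hedged sketch of what such a shared wrapper could look like, reconstructed from the deleted one_cycle / cosine_with_restarts branches and the keyword arguments visible at the new call site; the name of the first positional parameter and the internal layout are assumptions.

import math
from diffusers.optimization import (
    get_scheduler as get_diffusers_scheduler,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)
from training.optimization import get_one_cycle_schedule


def get_scheduler(
    id, optimizer, min_lr, lr, warmup_func, annealing_func, warmup_exp, annealing_exp,
    cycles, warmup_epochs, max_train_steps, num_update_steps_per_epoch,
    gradient_accumulation_steps,
):
    # Hypothetical consolidation of the per-script scheduler selection.
    num_training_steps = max_train_steps * gradient_accumulation_steps
    warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps

    if id == "one_cycle":
        # Same default as the removed code: 4% of the base LR unless overridden.
        min_lr = 0.04 if min_lr is None else min_lr / lr
        return get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=num_training_steps,
            warmup=warmup_func,
            annealing=annealing_func,
            warmup_exp=warmup_exp,
            annealing_exp=annealing_exp,
            min_lr=min_lr,
        )

    if id == "cosine_with_restarts":
        return get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=cycles or math.ceil(
                math.sqrt((max_train_steps - warmup_steps) / num_update_steps_per_epoch)
            ),
        )

    # Fall back to the stock diffusers schedulers for every other name.
    return get_diffusers_scheduler(
        id,
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )

Centralizing this in one helper is what lets train_dreambooth.py drop both the diffusers.optimization and training.optimization imports and pass raw args straight through, which is the deduplication this commit describes.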