summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py86
1 files changed, 26 insertions, 60 deletions
diff --git a/train_ti.py b/train_ti.py
index 9ec5cfb..3b7e3b1 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,7 +13,6 @@ from accelerate import Accelerator
13from accelerate.logging import get_logger 13from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 14from accelerate.utils import LoggerType, set_seed
15from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel 15from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
16from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
17import matplotlib.pyplot as plt 16import matplotlib.pyplot as plt
18from tqdm.auto import tqdm 17from tqdm.auto import tqdm
19from transformers import CLIPTextModel 18from transformers import CLIPTextModel
@@ -22,8 +21,7 @@ from slugify import slugify
22from util import load_config, load_embeddings_from_dir 21from util import load_config, load_embeddings_from_dir
23from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 22from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
24from data.csv import VlpnDataModule, VlpnDataItem 23from data.csv import VlpnDataModule, VlpnDataItem
25from training.common import loss_step, generate_class_images 24from training.common import loss_step, generate_class_images, get_scheduler
26from training.optimization import get_one_cycle_schedule
27from training.lr import LRFinder 25from training.lr import LRFinder
28from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args 26from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
29from models.clip.embeddings import patch_managed_embeddings 27from models.clip.embeddings import patch_managed_embeddings
@@ -410,10 +408,16 @@ def parse_args():
410 help="The weight of prior preservation loss." 408 help="The weight of prior preservation loss."
411 ) 409 )
412 parser.add_argument( 410 parser.add_argument(
413 "--max_grad_norm", 411 "--decay_target",
414 default=3.0, 412 default=0.4,
415 type=float, 413 type=float,
416 help="Max gradient norm." 414 help="Embedding decay target."
415 )
416 parser.add_argument(
417 "--decay_factor",
418 default=100,
419 type=float,
420 help="Embedding decay factor."
417 ) 421 )
418 parser.add_argument( 422 parser.add_argument(
419 "--noise_timesteps", 423 "--noise_timesteps",
@@ -709,35 +713,6 @@ def main():
709 ) 713 )
710 return cond1 and cond3 and cond4 714 return cond1 and cond3 and cond4
711 715
712 def collate_fn(examples):
713 prompt_ids = [example["prompt_ids"] for example in examples]
714 nprompt_ids = [example["nprompt_ids"] for example in examples]
715
716 input_ids = [example["instance_prompt_ids"] for example in examples]
717 pixel_values = [example["instance_images"] for example in examples]
718
719 # concat class and instance examples for prior preservation
720 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
721 input_ids += [example["class_prompt_ids"] for example in examples]
722 pixel_values += [example["class_images"] for example in examples]
723
724 pixel_values = torch.stack(pixel_values)
725 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
726
727 prompts = prompt_processor.unify_input_ids(prompt_ids)
728 nprompts = prompt_processor.unify_input_ids(nprompt_ids)
729 inputs = prompt_processor.unify_input_ids(input_ids)
730
731 batch = {
732 "prompt_ids": prompts.input_ids,
733 "nprompt_ids": nprompts.input_ids,
734 "input_ids": inputs.input_ids,
735 "pixel_values": pixel_values,
736 "attention_mask": inputs.attention_mask,
737 }
738
739 return batch
740
741 datamodule = VlpnDataModule( 716 datamodule = VlpnDataModule(
742 data_file=args.train_data_file, 717 data_file=args.train_data_file,
743 batch_size=args.train_batch_size, 718 batch_size=args.train_batch_size,
@@ -757,7 +732,7 @@ def main():
757 num_workers=args.dataloader_num_workers, 732 num_workers=args.dataloader_num_workers,
758 seed=args.seed, 733 seed=args.seed,
759 filter=keyword_filter, 734 filter=keyword_filter,
760 collate_fn=collate_fn 735 dtype=weight_dtype
761 ) 736 )
762 datamodule.setup() 737 datamodule.setup()
763 738
@@ -786,35 +761,23 @@ def main():
786 overrode_max_train_steps = True 761 overrode_max_train_steps = True
787 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 762 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
788 763
789 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
790
791 if args.find_lr: 764 if args.find_lr:
792 lr_scheduler = None 765 lr_scheduler = None
793 elif args.lr_scheduler == "one_cycle":
794 lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
795 lr_scheduler = get_one_cycle_schedule(
796 optimizer=optimizer,
797 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
798 warmup=args.lr_warmup_func,
799 annealing=args.lr_annealing_func,
800 warmup_exp=args.lr_warmup_exp,
801 annealing_exp=args.lr_annealing_exp,
802 min_lr=lr_min_lr,
803 )
804 elif args.lr_scheduler == "cosine_with_restarts":
805 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
806 optimizer=optimizer,
807 num_warmup_steps=warmup_steps,
808 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
809 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
810 ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
811 )
812 else: 766 else:
813 lr_scheduler = get_scheduler( 767 lr_scheduler = get_scheduler(
814 args.lr_scheduler, 768 args.lr_scheduler,
815 optimizer=optimizer, 769 optimizer=optimizer,
816 num_warmup_steps=warmup_steps, 770 min_lr=args.lr_min_lr,
817 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 771 lr=args.learning_rate,
772 warmup_func=args.lr_warmup_func,
773 annealing_func=args.lr_annealing_func,
774 warmup_exp=args.lr_warmup_exp,
775 annealing_exp=args.lr_annealing_exp,
776 cycles=args.lr_cycles,
777 warmup_epochs=args.lr_warmup_epochs,
778 max_train_steps=args.max_train_steps,
779 num_update_steps_per_epoch=num_update_steps_per_epoch,
780 gradient_accumulation_steps=args.gradient_accumulation_steps
818 ) 781 )
819 782
820 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 783 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
@@ -868,7 +831,10 @@ def main():
868 831
869 @torch.no_grad() 832 @torch.no_grad()
870 def on_after_optimize(lr: float): 833 def on_after_optimize(lr: float):
871 text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) 834 text_encoder.text_model.embeddings.normalize(
835 args.decay_target,
836 min(1.0, args.decay_factor * lr)
837 )
872 838
873 loop = partial( 839 loop = partial(
874 loss_step, 840 loss_step,