diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 86 |
1 files changed, 26 insertions, 60 deletions
diff --git a/train_ti.py b/train_ti.py index 9ec5cfb..3b7e3b1 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -13,7 +13,6 @@ from accelerate import Accelerator | |||
13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel | 15 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
16 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | ||
17 | import matplotlib.pyplot as plt | 16 | import matplotlib.pyplot as plt |
18 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
19 | from transformers import CLIPTextModel | 18 | from transformers import CLIPTextModel |
@@ -22,8 +21,7 @@ from slugify import slugify | |||
22 | from util import load_config, load_embeddings_from_dir | 21 | from util import load_config, load_embeddings_from_dir |
23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 22 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
24 | from data.csv import VlpnDataModule, VlpnDataItem | 23 | from data.csv import VlpnDataModule, VlpnDataItem |
25 | from training.common import loss_step, generate_class_images | 24 | from training.common import loss_step, generate_class_images, get_scheduler |
26 | from training.optimization import get_one_cycle_schedule | ||
27 | from training.lr import LRFinder | 25 | from training.lr import LRFinder |
28 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args | 26 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args |
29 | from models.clip.embeddings import patch_managed_embeddings | 27 | from models.clip.embeddings import patch_managed_embeddings |
@@ -410,10 +408,16 @@ def parse_args(): | |||
410 | help="The weight of prior preservation loss." | 408 | help="The weight of prior preservation loss." |
411 | ) | 409 | ) |
412 | parser.add_argument( | 410 | parser.add_argument( |
413 | "--max_grad_norm", | 411 | "--decay_target", |
414 | default=3.0, | 412 | default=0.4, |
415 | type=float, | 413 | type=float, |
416 | help="Max gradient norm." | 414 | help="Embedding decay target." |
415 | ) | ||
416 | parser.add_argument( | ||
417 | "--decay_factor", | ||
418 | default=100, | ||
419 | type=float, | ||
420 | help="Embedding decay factor." | ||
417 | ) | 421 | ) |
418 | parser.add_argument( | 422 | parser.add_argument( |
419 | "--noise_timesteps", | 423 | "--noise_timesteps", |
@@ -709,35 +713,6 @@ def main(): | |||
709 | ) | 713 | ) |
710 | return cond1 and cond3 and cond4 | 714 | return cond1 and cond3 and cond4 |
711 | 715 | ||
712 | def collate_fn(examples): | ||
713 | prompt_ids = [example["prompt_ids"] for example in examples] | ||
714 | nprompt_ids = [example["nprompt_ids"] for example in examples] | ||
715 | |||
716 | input_ids = [example["instance_prompt_ids"] for example in examples] | ||
717 | pixel_values = [example["instance_images"] for example in examples] | ||
718 | |||
719 | # concat class and instance examples for prior preservation | ||
720 | if args.num_class_images != 0 and "class_prompt_ids" in examples[0]: | ||
721 | input_ids += [example["class_prompt_ids"] for example in examples] | ||
722 | pixel_values += [example["class_images"] for example in examples] | ||
723 | |||
724 | pixel_values = torch.stack(pixel_values) | ||
725 | pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format) | ||
726 | |||
727 | prompts = prompt_processor.unify_input_ids(prompt_ids) | ||
728 | nprompts = prompt_processor.unify_input_ids(nprompt_ids) | ||
729 | inputs = prompt_processor.unify_input_ids(input_ids) | ||
730 | |||
731 | batch = { | ||
732 | "prompt_ids": prompts.input_ids, | ||
733 | "nprompt_ids": nprompts.input_ids, | ||
734 | "input_ids": inputs.input_ids, | ||
735 | "pixel_values": pixel_values, | ||
736 | "attention_mask": inputs.attention_mask, | ||
737 | } | ||
738 | |||
739 | return batch | ||
740 | |||
741 | datamodule = VlpnDataModule( | 716 | datamodule = VlpnDataModule( |
742 | data_file=args.train_data_file, | 717 | data_file=args.train_data_file, |
743 | batch_size=args.train_batch_size, | 718 | batch_size=args.train_batch_size, |
@@ -757,7 +732,7 @@ def main(): | |||
757 | num_workers=args.dataloader_num_workers, | 732 | num_workers=args.dataloader_num_workers, |
758 | seed=args.seed, | 733 | seed=args.seed, |
759 | filter=keyword_filter, | 734 | filter=keyword_filter, |
760 | collate_fn=collate_fn | 735 | dtype=weight_dtype |
761 | ) | 736 | ) |
762 | datamodule.setup() | 737 | datamodule.setup() |
763 | 738 | ||
@@ -786,35 +761,23 @@ def main(): | |||
786 | overrode_max_train_steps = True | 761 | overrode_max_train_steps = True |
787 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) | 762 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
788 | 763 | ||
789 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps | ||
790 | |||
791 | if args.find_lr: | 764 | if args.find_lr: |
792 | lr_scheduler = None | 765 | lr_scheduler = None |
793 | elif args.lr_scheduler == "one_cycle": | ||
794 | lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate | ||
795 | lr_scheduler = get_one_cycle_schedule( | ||
796 | optimizer=optimizer, | ||
797 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
798 | warmup=args.lr_warmup_func, | ||
799 | annealing=args.lr_annealing_func, | ||
800 | warmup_exp=args.lr_warmup_exp, | ||
801 | annealing_exp=args.lr_annealing_exp, | ||
802 | min_lr=lr_min_lr, | ||
803 | ) | ||
804 | elif args.lr_scheduler == "cosine_with_restarts": | ||
805 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | ||
806 | optimizer=optimizer, | ||
807 | num_warmup_steps=warmup_steps, | ||
808 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | ||
809 | num_cycles=args.lr_cycles or math.ceil(math.sqrt( | ||
810 | ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))), | ||
811 | ) | ||
812 | else: | 766 | else: |
813 | lr_scheduler = get_scheduler( | 767 | lr_scheduler = get_scheduler( |
814 | args.lr_scheduler, | 768 | args.lr_scheduler, |
815 | optimizer=optimizer, | 769 | optimizer=optimizer, |
816 | num_warmup_steps=warmup_steps, | 770 | min_lr=args.lr_min_lr, |
817 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 771 | lr=args.learning_rate, |
772 | warmup_func=args.lr_warmup_func, | ||
773 | annealing_func=args.lr_annealing_func, | ||
774 | warmup_exp=args.lr_warmup_exp, | ||
775 | annealing_exp=args.lr_annealing_exp, | ||
776 | cycles=args.lr_cycles, | ||
777 | warmup_epochs=args.lr_warmup_epochs, | ||
778 | max_train_steps=args.max_train_steps, | ||
779 | num_update_steps_per_epoch=num_update_steps_per_epoch, | ||
780 | gradient_accumulation_steps=args.gradient_accumulation_steps | ||
818 | ) | 781 | ) |
819 | 782 | ||
820 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 783 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
@@ -868,7 +831,10 @@ def main(): | |||
868 | 831 | ||
869 | @torch.no_grad() | 832 | @torch.no_grad() |
870 | def on_after_optimize(lr: float): | 833 | def on_after_optimize(lr: float): |
871 | text_encoder.text_model.embeddings.normalize(min(1.0, 100 * lr)) | 834 | text_encoder.text_model.embeddings.normalize( |
835 | args.decay_target, | ||
836 | min(1.0, args.decay_factor * lr) | ||
837 | ) | ||
872 | 838 | ||
873 | loop = partial( | 839 | loop = partial( |
874 | loss_step, | 840 | loss_step, |