diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 94 |
1 files changed, 46 insertions, 48 deletions
diff --git a/train_ti.py b/train_ti.py index 0ffc9e6..89c6672 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -21,7 +21,7 @@ from slugify import slugify | |||
21 | 21 | ||
22 | from util import load_config, load_embeddings_from_dir | 22 | from util import load_config, load_embeddings_from_dir |
23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
24 | from data.csv import CSVDataModule, CSVDataItem | 24 | from data.csv import VlpnDataModule, VlpnDataItem |
25 | from training.common import run_model | 25 | from training.common import run_model |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
@@ -146,11 +146,6 @@ def parse_args(): | |||
146 | ), | 146 | ), |
147 | ) | 147 | ) |
148 | parser.add_argument( | 148 | parser.add_argument( |
149 | "--center_crop", | ||
150 | action="store_true", | ||
151 | help="Whether to center crop images before resizing to resolution" | ||
152 | ) | ||
153 | parser.add_argument( | ||
154 | "--tag_dropout", | 149 | "--tag_dropout", |
155 | type=float, | 150 | type=float, |
156 | default=0.1, | 151 | default=0.1, |
@@ -668,7 +663,7 @@ def main(): | |||
668 | elif args.mixed_precision == "bf16": | 663 | elif args.mixed_precision == "bf16": |
669 | weight_dtype = torch.bfloat16 | 664 | weight_dtype = torch.bfloat16 |
670 | 665 | ||
671 | def keyword_filter(item: CSVDataItem): | 666 | def keyword_filter(item: VlpnDataItem): |
672 | cond1 = any( | 667 | cond1 = any( |
673 | keyword in part | 668 | keyword in part |
674 | for keyword in args.placeholder_token | 669 | for keyword in args.placeholder_token |
@@ -708,7 +703,7 @@ def main(): | |||
708 | } | 703 | } |
709 | return batch | 704 | return batch |
710 | 705 | ||
711 | datamodule = CSVDataModule( | 706 | datamodule = VlpnDataModule( |
712 | data_file=args.train_data_file, | 707 | data_file=args.train_data_file, |
713 | batch_size=args.train_batch_size, | 708 | batch_size=args.train_batch_size, |
714 | prompt_processor=prompt_processor, | 709 | prompt_processor=prompt_processor, |
@@ -717,7 +712,6 @@ def main(): | |||
717 | size=args.resolution, | 712 | size=args.resolution, |
718 | repeats=args.repeats, | 713 | repeats=args.repeats, |
719 | dropout=args.tag_dropout, | 714 | dropout=args.tag_dropout, |
720 | center_crop=args.center_crop, | ||
721 | template_key=args.train_data_template, | 715 | template_key=args.train_data_template, |
722 | valid_set_size=args.valid_set_size, | 716 | valid_set_size=args.valid_set_size, |
723 | num_workers=args.dataloader_num_workers, | 717 | num_workers=args.dataloader_num_workers, |
@@ -725,8 +719,6 @@ def main(): | |||
725 | filter=keyword_filter, | 719 | filter=keyword_filter, |
726 | collate_fn=collate_fn | 720 | collate_fn=collate_fn |
727 | ) | 721 | ) |
728 | |||
729 | datamodule.prepare_data() | ||
730 | datamodule.setup() | 722 | datamodule.setup() |
731 | 723 | ||
732 | if args.num_class_images != 0: | 724 | if args.num_class_images != 0: |
@@ -769,12 +761,14 @@ def main(): | |||
769 | if torch.cuda.is_available(): | 761 | if torch.cuda.is_available(): |
770 | torch.cuda.empty_cache() | 762 | torch.cuda.empty_cache() |
771 | 763 | ||
772 | train_dataloader = datamodule.train_dataloader() | 764 | train_dataloaders = datamodule.train_dataloaders |
773 | val_dataloader = datamodule.val_dataloader() | 765 | default_train_dataloader = train_dataloaders[0] |
766 | val_dataloader = datamodule.val_dataloader | ||
774 | 767 | ||
775 | # Scheduler and math around the number of training steps. | 768 | # Scheduler and math around the number of training steps. |
776 | overrode_max_train_steps = False | 769 | overrode_max_train_steps = False |
777 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 770 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) |
771 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
778 | if args.max_train_steps is None: | 772 | if args.max_train_steps is None: |
779 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 773 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
780 | overrode_max_train_steps = True | 774 | overrode_max_train_steps = True |
@@ -811,9 +805,10 @@ def main(): | |||
811 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 805 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
812 | ) | 806 | ) |
813 | 807 | ||
814 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 808 | text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare( |
815 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 809 | text_encoder, optimizer, val_dataloader, lr_scheduler |
816 | ) | 810 | ) |
811 | train_dataloaders = accelerator.prepare(*train_dataloaders) | ||
817 | 812 | ||
818 | # Move vae and unet to device | 813 | # Move vae and unet to device |
819 | vae.to(accelerator.device, dtype=weight_dtype) | 814 | vae.to(accelerator.device, dtype=weight_dtype) |
@@ -831,7 +826,8 @@ def main(): | |||
831 | unet.eval() | 826 | unet.eval() |
832 | 827 | ||
833 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 828 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. |
834 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 829 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) |
830 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
835 | if overrode_max_train_steps: | 831 | if overrode_max_train_steps: |
836 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 832 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
837 | 833 | ||
@@ -889,7 +885,7 @@ def main(): | |||
889 | accelerator, | 885 | accelerator, |
890 | text_encoder, | 886 | text_encoder, |
891 | optimizer, | 887 | optimizer, |
892 | train_dataloader, | 888 | default_train_dataloader, |
893 | val_dataloader, | 889 | val_dataloader, |
894 | loop, | 890 | loop, |
895 | on_train=on_train, | 891 | on_train=on_train, |
@@ -968,46 +964,48 @@ def main(): | |||
968 | text_encoder.train() | 964 | text_encoder.train() |
969 | 965 | ||
970 | with on_train(): | 966 | with on_train(): |
971 | for step, batch in enumerate(train_dataloader): | 967 | for train_dataloader in train_dataloaders: |
972 | with accelerator.accumulate(text_encoder): | 968 | for step, batch in enumerate(train_dataloader): |
973 | loss, acc, bsz = loop(step, batch) | 969 | with accelerator.accumulate(text_encoder): |
970 | loss, acc, bsz = loop(step, batch) | ||
974 | 971 | ||
975 | accelerator.backward(loss) | 972 | accelerator.backward(loss) |
976 | 973 | ||
977 | optimizer.step() | 974 | optimizer.step() |
978 | if not accelerator.optimizer_step_was_skipped: | 975 | if not accelerator.optimizer_step_was_skipped: |
979 | lr_scheduler.step() | 976 | lr_scheduler.step() |
980 | optimizer.zero_grad(set_to_none=True) | 977 | optimizer.zero_grad(set_to_none=True) |
981 | 978 | ||
982 | avg_loss.update(loss.detach_(), bsz) | 979 | avg_loss.update(loss.detach_(), bsz) |
983 | avg_acc.update(acc.detach_(), bsz) | 980 | avg_acc.update(acc.detach_(), bsz) |
984 | 981 | ||
985 | # Checks if the accelerator has performed an optimization step behind the scenes | 982 | # Checks if the accelerator has performed an optimization step behind the scenes |
986 | if accelerator.sync_gradients: | 983 | if accelerator.sync_gradients: |
987 | if args.use_ema: | 984 | if args.use_ema: |
988 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 985 | ema_embeddings.step( |
986 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
989 | 987 | ||
990 | local_progress_bar.update(1) | 988 | local_progress_bar.update(1) |
991 | global_progress_bar.update(1) | 989 | global_progress_bar.update(1) |
992 | 990 | ||
993 | global_step += 1 | 991 | global_step += 1 |
994 | 992 | ||
995 | logs = { | 993 | logs = { |
996 | "train/loss": avg_loss.avg.item(), | 994 | "train/loss": avg_loss.avg.item(), |
997 | "train/acc": avg_acc.avg.item(), | 995 | "train/acc": avg_acc.avg.item(), |
998 | "train/cur_loss": loss.item(), | 996 | "train/cur_loss": loss.item(), |
999 | "train/cur_acc": acc.item(), | 997 | "train/cur_acc": acc.item(), |
1000 | "lr": lr_scheduler.get_last_lr()[0], | 998 | "lr": lr_scheduler.get_last_lr()[0], |
1001 | } | 999 | } |
1002 | if args.use_ema: | 1000 | if args.use_ema: |
1003 | logs["ema_decay"] = ema_embeddings.decay | 1001 | logs["ema_decay"] = ema_embeddings.decay |
1004 | 1002 | ||
1005 | accelerator.log(logs, step=global_step) | 1003 | accelerator.log(logs, step=global_step) |
1006 | 1004 | ||
1007 | local_progress_bar.set_postfix(**logs) | 1005 | local_progress_bar.set_postfix(**logs) |
1008 | 1006 | ||
1009 | if global_step >= args.max_train_steps: | 1007 | if global_step >= args.max_train_steps: |
1010 | break | 1008 | break |
1011 | 1009 | ||
1012 | accelerator.wait_for_everyone() | 1010 | accelerator.wait_for_everyone() |
1013 | 1011 | ||