path: root/train_ti.py
author	Volpeon <git@volpeon.ink>	2023-01-07 13:57:46 +0100
committer	Volpeon <git@volpeon.ink>	2023-01-07 13:57:46 +0100
commit	3ee13893f9a4973ac75f45fe9318c35760dd4b1f (patch)
tree	e652a54e6c241eef52ddb30f2d7048da8f306f7b /train_ti.py
parent	Update (diff)
Added progressive aspect ratio bucketing
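
This commit switches the data pipeline from CSVDataModule/CSVDataItem to VlpnDataModule/VlpnDataItem, drops the --center_crop option, and has the data module expose one training dataloader per aspect-ratio bucket (datamodule.train_dataloaders) instead of a single train_dataloader(); the step accounting and the epoch loop are updated to match. The bucketing itself lives in data/csv.py, which this diff does not cover. As a rough illustration of the idea only, here is a minimal sketch of aspect-ratio bucketing; the names BucketDataset, bucket_resolutions and make_bucket_dataloaders are hypothetical and are not the module's actual API.

    # Hypothetical sketch of aspect-ratio bucketing (not the data/csv.py implementation).
    from collections import defaultdict
    from typing import Iterable, List, Tuple

    from torch.utils.data import DataLoader, Dataset


    class BucketDataset(Dataset):
        """Holds the items assigned to a single aspect-ratio bucket."""

        def __init__(self, items: List[Tuple[str, int, int]], size: Tuple[int, int]):
            self.items = items
            self.size = size  # every batch drawn from this bucket shares this (width, height)

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            path, width, height = self.items[idx]
            # Placeholder: load the image at `path` and resize it to `self.size` here.
            return {"path": path, "size": self.size}


    def bucket_resolutions(base=768, step=64, num=4):
        """Candidate (width, height) pairs: one square bucket plus wider and taller variants."""
        sizes = [(base, base)]
        for i in range(1, num + 1):
            sizes.append((base + i * step, base))  # wider than square
            sizes.append((base, base + i * step))  # taller than square
        return sizes


    def make_bucket_dataloaders(items: Iterable[Tuple[str, int, int]], batch_size: int):
        """Assign each (path, width, height) item to its closest bucket, one DataLoader per bucket."""
        sizes = bucket_resolutions()
        buckets = defaultdict(list)
        for path, width, height in items:
            ratio = width / height
            best = min(sizes, key=lambda s: abs(s[0] / s[1] - ratio))
            buckets[best].append((path, width, height))
        return [
            DataLoader(BucketDataset(bucket_items, size), batch_size=batch_size, shuffle=True)
            for size, bucket_items in buckets.items()
        ]

Because every batch from a bucket shares one resolution, no center crop is needed, which is presumably why --center_crop is removed in this commit.
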
Diffstat (limited to 'train_ti.py')
-rw-r--r--	train_ti.py	94
1 file changed, 46 insertions, 48 deletions
diff --git a/train_ti.py b/train_ti.py
index 0ffc9e6..89c6672 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,7 +21,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule, CSVDataItem
+from data.csv import VlpnDataModule, VlpnDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
@@ -146,11 +146,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
         "--tag_dropout",
         type=float,
         default=0.1,
@@ -668,7 +663,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: CSVDataItem):
+    def keyword_filter(item: VlpnDataItem):
         cond1 = any(
             keyword in part
             for keyword in args.placeholder_token
@@ -708,7 +703,7 @@ def main():
         }
         return batch
 
-    datamodule = CSVDataModule(
+    datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
@@ -717,7 +712,6 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         dropout=args.tag_dropout,
-        center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -725,8 +719,6 @@ def main():
         filter=keyword_filter,
         collate_fn=collate_fn
     )
-
-    datamodule.prepare_data()
     datamodule.setup()
 
     if args.num_class_images != 0:
@@ -769,12 +761,14 @@ def main():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
+    train_dataloaders = datamodule.train_dataloaders
+    default_train_dataloader = train_dataloaders[0]
+    val_dataloader = datamodule.val_dataloader
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -811,9 +805,10 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, val_dataloader, lr_scheduler
     )
+    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -831,7 +826,8 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -889,7 +885,7 @@ def main():
            accelerator,
            text_encoder,
            optimizer,
-           train_dataloader,
+           default_train_dataloader,
            val_dataloader,
            loop,
            on_train=on_train,
@@ -968,46 +964,48 @@ def main():
             text_encoder.train()
 
             with on_train():
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        if args.use_ema:
-                            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0],
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = ema_embeddings.decay
-
-                        accelerator.log(logs, step=global_step)
-
-                        local_progress_bar.set_postfix(**logs)
-
-                    if global_step >= args.max_train_steps:
-                        break
+                for train_dataloader in train_dataloaders:
+                    for step, batch in enumerate(train_dataloader):
+                        with accelerator.accumulate(text_encoder):
+                            loss, acc, bsz = loop(step, batch)
+
+                            accelerator.backward(loss)
+
+                            optimizer.step()
+                            if not accelerator.optimizer_step_was_skipped:
+                                lr_scheduler.step()
+                            optimizer.zero_grad(set_to_none=True)
+
+                            avg_loss.update(loss.detach_(), bsz)
+                            avg_acc.update(acc.detach_(), bsz)
+
+                        # Checks if the accelerator has performed an optimization step behind the scenes
+                        if accelerator.sync_gradients:
+                            if args.use_ema:
+                                ema_embeddings.step(
+                                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+                            local_progress_bar.update(1)
+                            global_progress_bar.update(1)
+
+                            global_step += 1
+
+                            logs = {
+                                "train/loss": avg_loss.avg.item(),
+                                "train/acc": avg_acc.avg.item(),
+                                "train/cur_loss": loss.item(),
+                                "train/cur_acc": acc.item(),
+                                "lr": lr_scheduler.get_last_lr()[0],
+                            }
+                            if args.use_ema:
+                                logs["ema_decay"] = ema_embeddings.decay
+
+                            accelerator.log(logs, step=global_step)
+
+                            local_progress_bar.set_postfix(**logs)
+
+                        if global_step >= args.max_train_steps:
+                            break
 
             accelerator.wait_for_everyone()
 
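
The net effect of the hunks above: datamodule.train_dataloaders now yields one dataloader per aspect-ratio bucket, the first one is kept as default_train_dataloader (and is what the LR-finder path receives), the per-epoch step count is the summed length of all bucket dataloaders, and each epoch walks every bucket dataloader in turn. A small worked example of the new step accounting, using made-up bucket sizes and an assumed gradient_accumulation_steps of 4:

    # Worked example of the step math introduced in this commit (bucket sizes are made up).
    import math

    dataloader_lengths = [120, 45, 7]      # len(dataloader) for each aspect-ratio bucket
    gradient_accumulation_steps = 4

    num_update_steps_per_dataloader = sum(dataloader_lengths)        # 172 batches per epoch
    num_update_steps_per_epoch = math.ceil(
        num_update_steps_per_dataloader / gradient_accumulation_steps)  # 43 optimizer steps

    num_train_epochs = 100
    max_train_steps = num_train_epochs * num_update_steps_per_epoch     # 4300
    print(num_update_steps_per_epoch, max_train_steps)
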