Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  85
1 file changed, 37 insertions, 48 deletions
diff --git a/train_ti.py b/train_ti.py
index b4b602b..727b591 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -107,12 +107,6 @@ def parse_args():
107 help="Exclude all items with a listed collection.", 107 help="Exclude all items with a listed collection.",
108 ) 108 )
109 parser.add_argument( 109 parser.add_argument(
110 "--repeats",
111 type=int,
112 default=1,
113 help="How many times to repeat the training data."
114 )
115 parser.add_argument(
116 "--output_dir", 110 "--output_dir",
117 type=str, 111 type=str,
118 default="output/text-inversion", 112 default="output/text-inversion",
@@ -722,7 +716,6 @@ def main():
         size=args.resolution,
         num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
         progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -733,7 +726,7 @@ def main():
     )
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -752,8 +745,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -790,10 +782,9 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, val_dataloader, lr_scheduler
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
-    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -811,8 +802,7 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -870,7 +860,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=on_train,
@@ -949,48 +939,47 @@ def main():
         text_encoder.train()
 
         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)
 
-                        accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
 
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        if args.use_ema:
-                            ema_embeddings.step(
-                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(
+                            text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                        global_step += 1
+                    global_step += 1
 
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    if args.use_ema:
-                        logs["ema_decay"] = ema_embeddings.decay
+                logs = {
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
+                    "train/cur_loss": loss.item(),
+                    "train/cur_acc": acc.item(),
+                    "lr": lr_scheduler.get_last_lr()[0],
+                }
+                if args.use_ema:
+                    logs["ema_decay"] = ema_embeddings.decay
 
-                    accelerator.log(logs, step=global_step)
+                accelerator.log(logs, step=global_step)
 
-                    local_progress_bar.set_postfix(**logs)
+                local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 