path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
commit     5571c4ebcb39813e2bd8585de30c64bb02f9d7fa (patch)
tree       a073f625eaa49c3cd908aacb3debae23e5badbf7 /train_dreambooth.py
parent     Cleanup (diff)
download   textual-inversion-diff-5571c4ebcb39813e2bd8585de30c64bb02f9d7fa.tar.gz
           textual-inversion-diff-5571c4ebcb39813e2bd8585de30c64bb02f9d7fa.tar.bz2
           textual-inversion-diff-5571c4ebcb39813e2bd8585de30c64bb02f9d7fa.zip
Improved aspect ratio bucketing
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  100
1 file changed, 45 insertions, 55 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 589af59..42a7d0f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -134,12 +134,6 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/dreambooth",
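Note: the removed --repeats flag simply duplicated the training data N times. If that behaviour is still wanted, it can be emulated outside the datamodule; a minimal sketch, assuming a plain PyTorch dataset (the names below are illustrative, not part of this script):

# Hypothetical: emulate the removed --repeats by concatenating copies of the dataset.
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

base = TensorDataset(torch.arange(10).float().unsqueeze(1))  # stand-in training set
repeats = 3                                                  # what --repeats used to control
repeated = ConcatDataset([base] * repeats)                   # 30 samples instead of 10
loader = DataLoader(repeated, batch_size=4, shuffle=True)
print(len(repeated))  # 30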
@@ -738,7 +732,6 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -751,7 +744,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()

-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader

     if args.num_class_images != 0:
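With the improved bucketing the datamodule exposes one train_dataloader instead of a list with one dataloader per aspect-ratio bucket. A common way to get that shape is a batch sampler that only ever mixes indices from the same bucket, so a single DataLoader still yields resolution-consistent batches. A sketch under that assumption; BucketBatchSampler and bucket_ids are illustrative, not this repository's API:

# Hypothetical batch sampler: each batch is drawn from a single aspect-ratio
# bucket, but everything flows through one DataLoader.
import random
from torch.utils.data import DataLoader, Sampler

class BucketBatchSampler(Sampler):
    def __init__(self, bucket_ids, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.buckets = {}
        for idx, bucket in enumerate(bucket_ids):
            self.buckets.setdefault(bucket, []).append(idx)

    def __iter__(self):
        batches = []
        for indices in self.buckets.values():
            if self.shuffle:
                random.shuffle(indices)
            # chunk each bucket separately so batches never mix resolutions
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        if self.shuffle:
            random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return sum(-(-len(v) // self.batch_size) for v in self.buckets.values())

# usage (illustrative): bucket_ids[i] is the aspect-ratio bucket of sample i
# train_dataloader = DataLoader(dataset, batch_sampler=BucketBatchSampler(bucket_ids, batch_size=4))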
@@ -770,8 +763,7 @@ def main():

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
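The step accounting is now a single expression: one optimizer update happens every gradient_accumulation_steps batches, so updates per epoch are ceil(len(train_dataloader) / gradient_accumulation_steps). A worked example with illustrative numbers (not this script's defaults):

import math

batches_per_epoch = 250               # len(train_dataloader), illustrative
gradient_accumulation_steps = 4
num_train_epochs = 10

num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch
print(num_update_steps_per_epoch, max_train_steps)  # 63 630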
@@ -820,8 +812,7 @@ def main():
         ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

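The same expression is evaluated again here because accelerator.prepare() can change the dataloader's length (it is sharded across processes). A minimal sketch of that effect, assuming the accelerate library and toy data:

import math
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()
train_dataloader = DataLoader(TensorDataset(torch.randn(64, 4)), batch_size=8)

gradient_accumulation_steps = 2
steps_before = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

train_dataloader = accelerator.prepare(train_dataloader)
steps_after = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

# On a multi-process run the prepared dataloader is shorter, so steps_after < steps_before.
print(steps_before, steps_after)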
@@ -877,7 +868,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -960,54 +951,53 @@ def main():
         text_encoder.requires_grad_(False)

         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        if accelerator.sync_gradients:
-                            params_to_clip = (
-                                itertools.chain(unet.parameters(), text_encoder.parameters())
-                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                                else unet.parameters()
-                            )
-                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        if args.use_ema:
-                            ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
-
-                        global_step += 1
-
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0]
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = 1 - ema_unet.decay
-
-                        accelerator.log(logs, step=global_step)
-
-                        local_progress_bar.set_postfix(**logs)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    loss, acc, bsz = loop(step, batch)
+
+                    accelerator.backward(loss)
+
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                            else unet.parameters()
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    if args.use_ema:
+                        ema_unet.step(unet.parameters())
+                    optimizer.zero_grad(set_to_none=True)
+
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0]
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = 1 - ema_unet.decay
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)

             if global_step >= args.max_train_steps:
                 break

         accelerator.wait_for_everyone()

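The collapsed hunk above keeps the per-step pattern intact: accumulate, backward, clip only when gradients sync, step, optional EMA update, zero_grad, then bookkeeping outside the accumulate block. A self-contained sketch of that pattern with Hugging Face Accelerate, using a toy model in place of the UNet (all names here are stand-ins, not this script's objects):

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)                                    # stand-in for the UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
data = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
train_dataloader = DataLoader(data, batch_size=8, shuffle=True)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

global_step = 0
model.train()
for step, (x, y) in enumerate(train_dataloader):
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)

        if accelerator.sync_gradients:                           # real update step: clip here
            accelerator.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # An optimizer update only happened if gradients were synchronized this step.
    if accelerator.sync_gradients:
        global_step += 1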