author     Volpeon <git@volpeon.ink>   2022-12-11 12:59:13 +0100
committer  Volpeon <git@volpeon.ink>   2022-12-11 12:59:13 +0100
commit     8f2b8e8d309470babd9b853fde8f0a081366deae (patch)
tree       1374e791705e31fa77fefeb5001aad204cdf3224 /dreambooth.py
parent     Support attention_mask of text encoder (diff)
Training improvements such as tag dropout
Diffstat (limited to 'dreambooth.py')
-rw-r--r--   dreambooth.py   91
1 file changed, 58 insertions(+), 33 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 1ef5156..1d6735f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -89,6 +89,17 @@ def parse_args():
89 help="Whether to train the whole text encoder." 89 help="Whether to train the whole text encoder."
90 ) 90 )
91 parser.add_argument( 91 parser.add_argument(
92 "--train_text_encoder_epochs",
93 default=999999,
94 help="Number of epochs the text encoder will be trained."
95 )
96 parser.add_argument(
97 "--tag_dropout",
98 type=float,
99 default=0.1,
100 help="Tag dropout probability.",
101 )
102 parser.add_argument(
92 "--num_class_images", 103 "--num_class_images",
93 type=int, 104 type=int,
94 default=400, 105 default=400,
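The new --tag_dropout value is only passed through to the data module later in this diff; the dropout itself happens in the dataset code, which is not part of this patch. A minimal sketch of what dropping comma-separated caption tags with that probability could look like (apply_tag_dropout and the caption format are assumptions for illustration, not code from this repository):

import random

def apply_tag_dropout(prompt: str, dropout: float = 0.1) -> str:
    # Split the caption into comma-separated tags and drop each one
    # independently with probability `dropout`.
    tags = [t.strip() for t in prompt.split(",") if t.strip()]
    kept = [t for t in tags if random.random() >= dropout]
    # Never return an empty caption; fall back to the original tags.
    return ", ".join(kept if kept else tags)

# Example: with dropout=0.1 roughly one in ten tags is removed per sample.
print(apply_tag_dropout("portrait, outdoors, smile, blue sky", dropout=0.1))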
@@ -185,9 +196,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=500,
+        default=20,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -377,6 +388,20 @@ def make_grid(images, rows, cols):
     return grid
 
 
+class AverageMeter:
+    def __init__(self, name=None):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.sum = self.count = self.avg = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
 class Checkpointer:
     def __init__(
         self,
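AverageMeter keeps a count-weighted running mean and replaces the total_loss / global_step bookkeeping further down in this diff. A small standalone usage sketch:

meter = AverageMeter("loss")

# Two batches of different sizes: the average is weighted by n.
meter.update(0.50, n=4)   # batch of 4 samples, mean loss 0.50
meter.update(0.20, n=1)   # batch of 1 sample, mean loss 0.20

print(meter.avg)  # (0.50 * 4 + 0.20 * 1) / 5 = 0.44

In the training loop update() is fed tensors rather than floats, which is why the logging code below reads avg_loss.avg.item().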
@@ -744,6 +769,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -802,6 +828,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
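Warmup is now given in epochs and converted to scheduler steps from the dataset size. A quick worked example with made-up numbers (the dataloader length is purely illustrative):

# Hypothetical numbers, only to show the conversion done above.
lr_warmup_epochs = 20
gradient_accumulation_steps = 4
num_batches_per_epoch = 200  # len(train_dataloader), assumed

# The script presumably derives this as ceil(len(train_dataloader) / accumulation).
num_update_steps_per_epoch = num_batches_per_epoch // gradient_accumulation_steps  # 50

warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
print(warmup_steps)  # 20 * 50 * 4 = 4000 scheduler steps

The extra gradient_accumulation_steps factor matches how num_training_steps is scaled for the schedulers below.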
@@ -810,16 +838,16 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
+                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
 
@@ -863,11 +891,11 @@ def main():
 
     global_step = 0
 
-    total_loss = 0.0
-    total_acc = 0.0
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
 
-    total_loss_val = 0.0
-    total_acc_val = 0.0
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
 
     max_acc_val = 0.0
 
@@ -913,7 +941,11 @@ def main():
             local_progress_bar.reset()
 
         unet.train()
-        text_encoder.train()
+
+        if epoch < args.train_text_encoder_epochs:
+            text_encoder.train()
+        elif epoch == args.train_text_encoder_epochs:
+            freeze_params(text_encoder.parameters())
 
         sample_checkpoint = False
 
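freeze_params is a helper defined elsewhere in this repository and is not part of this diff; assuming it follows the usual pattern in these training scripts, it simply turns off gradient tracking once the configured number of text-encoder epochs has passed:

def freeze_params(params):
    # Presumed behaviour: stop producing gradients so the text encoder
    # no longer receives updates after train_text_encoder_epochs epochs.
    for param in params:
        param.requires_grad = False

Freezing only prevents new gradients; the optimizer still holds the text encoder's parameter groups.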
@@ -980,7 +1012,7 @@ def main():
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder
+                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
                         else unet.parameters()
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -992,11 +1024,10 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-            acc = (model_pred == latents).float()
-            acc = acc.mean()
+            acc = (model_pred == latents).float().mean()
 
-            total_loss += loss.item()
-            total_acc += acc.item()
+            avg_loss.update(loss.detach_(), bsz)
+            avg_acc.update(acc.detach_(), bsz)
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -1013,8 +1044,8 @@ def main():
                     sample_checkpoint = True
 
                 logs = {
-                    "train/loss": total_loss / global_step if global_step != 0 else 0,
-                    "train/acc": total_acc / global_step if global_step != 0 else 0,
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                     "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1064,41 +1095,35 @@ def main():
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                acc = (model_pred == latents).float()
-                acc = acc.mean()
+                acc = (model_pred == latents).float().mean()
 
-                total_loss_val += loss.item()
-                total_acc_val += acc.item()
+                avg_loss_val.update(loss.detach_(), bsz)
+                avg_acc_val.update(acc.detach_(), bsz)
 
                 if accelerator.sync_gradients:
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                 logs = {
-                    "val/loss": total_loss_val / global_step,
-                    "val/acc": total_acc_val / global_step,
+                    "val/loss": avg_loss_val.avg.item(),
+                    "val/acc": avg_acc_val.avg.item(),
                     "val/cur_loss": loss.item(),
                     "val/cur_acc": acc.item(),
                 }
                 local_progress_bar.set_postfix(**logs)
 
-            val_step = (epoch + 1) * len(val_dataloader)
-            avg_acc_val = total_acc_val / val_step
-            avg_loss_val = total_loss_val / val_step
-
             accelerator.log({
-                "val/loss": avg_loss_val,
-                "val/acc": avg_acc_val,
+                "val/loss": avg_loss_val.avg.item(),
+                "val/acc": avg_acc_val.avg.item(),
             }, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val > max_acc_val:
+            if avg_acc_val.avg.item() > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
-                checkpointer.save_embedding(global_step, "milestone")
-                max_acc_val = avg_acc_val
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                max_acc_val = avg_acc_val.avg.item()
 
         if sample_checkpoint and accelerator.is_main_process:
             checkpointer.save_samples(global_step, args.sample_steps)