 data/csv.py   | 13 +-
 dreambooth.py | 91 +-
 2 files changed, 68 insertions(+), 36 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 23b5299..9125212 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -16,14 +16,17 @@ def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-def shuffle_prompt(prompt: str):
+def shuffle_prompt(prompt: str, dropout: float = 0):
     def handle_block(block: str):
         words = block.split(", ")
+        words = [w for w in words if w != ""]
+        if dropout != 0:
+            words = [w for w in words if np.random.random() > dropout]
         np.random.shuffle(words)
         return ", ".join(words)
 
     prompt = prompt.split(". ")
-    prompt = [handle_block(b) for b in prompt]
+    prompt = [handle_block(b) for b in prompt if b != ""]
     np.random.shuffle(prompt)
     prompt = ". ".join(prompt)
     return prompt
@@ -48,6 +51,7 @@ class CSVDataModule(pl.LightningDataModule):
         num_class_images: int = 100,
         size: int = 512,
         repeats: int = 1,
+        dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
         valid_set_size: Optional[int] = None,
@@ -72,6 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
         self.class_identifier = class_identifier
         self.size = size
         self.repeats = repeats
+        self.dropout = dropout
         self.center_crop = center_crop
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -123,7 +128,7 @@ class CSVDataModule(pl.LightningDataModule):
                                    instance_identifier=self.instance_identifier, class_identifier=self.class_identifier,
                                    num_class_images=self.num_class_images,
                                    size=self.size, interpolation=self.interpolation,
-                                   center_crop=self.center_crop, repeats=self.repeats)
+                                   center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout)
         val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
                                  instance_identifier=self.instance_identifier,
                                  size=self.size, interpolation=self.interpolation,
@@ -153,6 +158,7 @@ class CSVDataset(Dataset):
         num_class_images: int = 0,
         size: int = 512,
         repeats: int = 1,
+        dropout: float = 0,
         interpolation: str = "bicubic",
         center_crop: bool = False,
     ):
@@ -163,6 +169,7 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
+        self.dropout = dropout
         self.image_cache = {}
 
         self.num_instance_images = len(self.data)
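
For context: with the new dropout parameter, shuffle_prompt now also drops each comma-separated tag with probability dropout before shuffling. A minimal standalone sketch of the updated function; the seed and the example prompt are illustrative, not from the repository:

    import numpy as np

    def shuffle_prompt(prompt: str, dropout: float = 0):
        # Shuffle the tags inside each ". "-separated block, dropping
        # each tag with probability `dropout` when dropout > 0.
        def handle_block(block: str):
            words = block.split(", ")
            words = [w for w in words if w != ""]
            if dropout != 0:
                words = [w for w in words if np.random.random() > dropout]
            np.random.shuffle(words)
            return ", ".join(words)

        prompt = prompt.split(". ")
        prompt = [handle_block(b) for b in prompt if b != ""]
        np.random.shuffle(prompt)
        prompt = ". ".join(prompt)
        return prompt

    np.random.seed(0)  # illustrative seed for a reproducible demo
    print(shuffle_prompt("a photo of a dog. sitting, grass, sunny day", dropout=0.1))
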
diff --git a/dreambooth.py b/dreambooth.py
index 1ef5156..1d6735f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -89,6 +89,18 @@ def parse_args():
         help="Whether to train the whole text encoder."
     )
     parser.add_argument(
+        "--train_text_encoder_epochs",
+        type=int,
+        default=999999,
+        help="Number of epochs the text encoder will be trained."
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=400,
@@ -185,9 +196,9 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps",
+        "--lr_warmup_epochs",
         type=int,
-        default=500,
-        help="Number of steps for the warmup in the lr scheduler."
+        default=20,
+        help="Number of epochs for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -377,6 +388,20 @@ def make_grid(images, rows, cols):
     return grid
 
 
+class AverageMeter:
+    def __init__(self, name=None):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.sum = self.count = self.avg = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
 class Checkpointer:
     def __init__(
         self,
@@ -744,6 +769,7 @@ def main():
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
+        dropout=args.tag_dropout,
         center_crop=args.center_crop,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -802,6 +828,8 @@ def main():
         overrode_max_train_steps = True
     num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
+    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
+
     if args.lr_scheduler == "one_cycle":
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
@@ -810,16 +838,16 @@ def main():
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
             num_cycles=args.lr_cycles or math.ceil(math.sqrt(
-                ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
+                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
         )
     else:
         lr_scheduler = get_scheduler(
             args.lr_scheduler,
             optimizer=optimizer,
-            num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+            num_warmup_steps=warmup_steps,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
         )
 
@@ -863,11 +891,11 @@ def main():
 
     global_step = 0
 
-    total_loss = 0.0
-    total_acc = 0.0
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
 
-    total_loss_val = 0.0
-    total_acc_val = 0.0
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
 
     max_acc_val = 0.0
 
@@ -913,7 +941,11 @@ def main():
             local_progress_bar.reset()
 
         unet.train()
-        text_encoder.train()
+
+        if epoch < args.train_text_encoder_epochs:
+            text_encoder.train()
+        elif epoch == args.train_text_encoder_epochs:
+            freeze_params(text_encoder.parameters())
 
         sample_checkpoint = False
 
@@ -980,7 +1012,7 @@ def main():
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder
+                        if args.train_text_encoder and epoch < args.train_text_encoder_epochs
                         else unet.parameters()
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -992,11 +1024,10 @@ def main():
                     ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
-                acc = (model_pred == latents).float()
-                acc = acc.mean()
+                acc = (model_pred == latents).float().mean()
 
-                total_loss += loss.item()
-                total_acc += acc.item()
+                avg_loss.update(loss.detach_(), bsz)
+                avg_acc.update(acc.detach_(), bsz)
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
@@ -1013,8 +1044,8 @@ def main():
                     sample_checkpoint = True
 
                 logs = {
-                    "train/loss": total_loss / global_step if global_step != 0 else 0,
-                    "train/acc": total_acc / global_step if global_step != 0 else 0,
+                    "train/loss": avg_loss.avg.item(),
+                    "train/acc": avg_acc.avg.item(),
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                     "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -1064,41 +1095,35 @@ def main():
 
                     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
-                    acc = (model_pred == latents).float()
-                    acc = acc.mean()
+                    acc = (model_pred == latents).float().mean()
 
-                    total_loss_val += loss.item()
-                    total_acc_val += acc.item()
+                    avg_loss_val.update(loss.detach_(), bsz)
+                    avg_acc_val.update(acc.detach_(), bsz)
 
                     if accelerator.sync_gradients:
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
                     logs = {
-                        "val/loss": total_loss_val / global_step,
-                        "val/acc": total_acc_val / global_step,
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            val_step = (epoch + 1) * len(val_dataloader)
-            avg_acc_val = total_acc_val / val_step
-            avg_loss_val = total_loss_val / val_step
-
             accelerator.log({
-                "val/loss": avg_loss_val,
-                "val/acc": avg_acc_val,
+                "val/loss": avg_loss_val.avg.item(),
+                "val/acc": avg_acc_val.avg.item(),
             }, step=global_step)
 
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if avg_acc_val > max_acc_val:
+            if avg_acc_val.avg.item() > max_acc_val:
                 accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new maximum: {max_acc_val:.2e} -> {avg_acc_val:.2e}")
-                checkpointer.save_embedding(global_step, "milestone")
-                max_acc_val = avg_acc_val
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                max_acc_val = avg_acc_val.avg.item()
 
             if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(global_step, args.sample_steps)
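
For reference, a minimal sketch of how the new pieces fit together: AverageMeter keeps a count-weighted running average (per-batch values weighted by batch size bsz), and warmup is now specified in epochs and converted to scheduler steps. The concrete numbers below are illustrative, not from the commit:

    class AverageMeter:
        # Count-weighted running average, as added in dreambooth.py.
        def __init__(self, name=None):
            self.name = name
            self.reset()

        def reset(self):
            self.sum = self.count = self.avg = 0

        def update(self, val, n=1):
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

    # Hypothetical training shape.
    num_update_steps_per_epoch = 800
    gradient_accumulation_steps = 4
    lr_warmup_epochs = 20  # the new --lr_warmup_epochs default

    # As computed in main(); the gradient_accumulation_steps factor matches
    # the num_training_steps = max_train_steps * gradient_accumulation_steps
    # convention the script already uses for its schedulers.
    warmup_steps = lr_warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
    print(warmup_steps)  # 64000

    # Batch-size-weighted averaging, mirroring avg_loss.update(loss, bsz)
    # in the training loop (there, val is a tensor and .avg.item() is logged).
    avg_loss = AverageMeter("loss")
    for loss, bsz in [(0.9, 4), (0.7, 4), (0.5, 2)]:
        avg_loss.update(loss, bsz)
    print(round(avg_loss.avg, 3))  # 0.74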